mirror of https://github.com/tracel-ai/burn.git
refactor: transpose (#71)
This commit is contained in:
parent
ad23898d23
commit
e6541298b9
|
@ -15,7 +15,6 @@ mod pow;
|
|||
mod precision;
|
||||
mod reshape;
|
||||
mod tensor;
|
||||
mod transpose;
|
||||
|
||||
mod macros;
|
||||
pub(crate) use base::*;
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
Backend,
|
||||
},
|
||||
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
|
||||
ops::{Ones, TensorOps, TensorOpsTranspose},
|
||||
ops::{Ones, TensorOps},
|
||||
Data, Shape,
|
||||
};
|
||||
|
||||
|
@ -399,7 +399,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
>,
|
||||
) -> B::TensorPrimitive<D> {
|
||||
let out_grad = state.output.grad();
|
||||
let rhs = state.right.value().transpose();
|
||||
let rhs = B::transpose(&state.right.value());
|
||||
B::matmul(&out_grad, &rhs)
|
||||
}
|
||||
|
||||
|
@ -412,7 +412,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
>,
|
||||
) -> B::TensorPrimitive<D> {
|
||||
let out_grad = state.output.grad();
|
||||
let lhs = state.left.value().transpose();
|
||||
let lhs = B::transpose(&state.left.value());
|
||||
B::matmul(&lhs, &out_grad)
|
||||
}
|
||||
}
|
||||
|
@ -447,4 +447,33 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
|
||||
unary_ops_wrapper(tensor.node.clone(), output, ops)
|
||||
}
|
||||
|
||||
fn swap_dims<const D: usize>(
|
||||
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
|
||||
#[derive(new, Debug)]
|
||||
struct SwapDimsBackward<B: Backend, const D: usize> {
|
||||
_b: B,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
|
||||
for SwapDimsBackward<B, D>
|
||||
{
|
||||
fn partial(
|
||||
&self,
|
||||
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
|
||||
) -> B::TensorPrimitive<D> {
|
||||
B::swap_dims(&state.output.grad(), self.dim2, self.dim1)
|
||||
}
|
||||
}
|
||||
|
||||
let output = B::swap_dims(tensor.tensor_ref(), dim1, dim2);
|
||||
let ops = SwapDimsBackward::<B, D>::new(B::default(), dim1, dim2);
|
||||
|
||||
unary_ops_wrapper(tensor.node.clone(), output, ops)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
use crate::graph::ops::{UnaryOps, UnaryOpsNodeState};
|
||||
use crate::tensor::backend::autodiff::ADTensor;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::ops::*;
|
||||
use crate::{execute_ops, register_ops};
|
||||
|
||||
// Declares `ADTensorTransposeOps` via the project's `register_ops!` macro.
// The `partial` closure returns the transposed output gradient.
// NOTE(review): the macro definition is not visible here; it presumably
// generates a `UnaryOps` implementor with a `new()` constructor — confirm.
register_ops!(
|
||||
ops UnaryOps,
|
||||
name ADTensorTransposeOps,
|
||||
partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
|
||||
state.output.grad().transpose()
|
||||
},
|
||||
);
|
||||
|
||||
/// Pair of dimension indices carried by the swap-dims backward operation.
#[derive(Debug)]
|
||||
struct DimState {
|
||||
// First dimension passed to `swap_dims`.
dim1: usize,
|
||||
// Second dimension passed to `swap_dims`.
dim2: usize,
|
||||
}
|
||||
|
||||
// Declares `ADTensorSwapDimOps`, whose extra state is a `DimState`.
// The `partial` closure swaps the stored dimensions (arguments reversed)
// on the output gradient.
// NOTE(review): the `register_ops!` grammar is defined elsewhere — confirm
// how the `state` clause is threaded into the closure.
register_ops!(
|
||||
ops UnaryOps,
|
||||
name ADTensorSwapDimOps state DimState,
|
||||
partial |dims: &DimState, state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
|
||||
state.output.grad().swap_dims(dims.dim2, dims.dim1)
|
||||
},
|
||||
);
|
||||
|
||||
// Autodiff implementation of `TensorOpsTranspose`: each method computes the
// forward value on the inner tensor and passes it, together with the input
// node and a backward op, to `execute_ops!`.
// NOTE(review): `execute_ops!` is defined elsewhere; presumably it registers
// the op in the autodiff graph — confirm against the macro.
impl<B: Backend, const D: usize> TensorOpsTranspose<B::Elem, D> for ADTensor<D, B> {
|
||||
// Transposes the inner tensor; backward op is `ADTensorTransposeOps`.
fn transpose(&self) -> Self {
|
||||
execute_ops!(
|
||||
input self.node.clone(),
|
||||
out TensorOpsTranspose::transpose(&self.tensor()),
|
||||
ops ADTensorTransposeOps::<B, D>::new(),
|
||||
)
|
||||
}
|
||||
// Swaps `dim1` and `dim2` on the inner tensor; the backward op
// `ADTensorSwapDimOps` receives both indices in a `DimState`.
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
|
||||
execute_ops!(
|
||||
input self.node.clone(),
|
||||
out TensorOpsTranspose::swap_dims(&self.tensor(), dim1, dim2),
|
||||
ops ADTensorSwapDimOps::<B, D>::new(DimState { dim1, dim2 }),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};

    /// Gradients of `matmul` combined with `transpose` on a 2-D input.
    #[test]
    fn should_diff_transpose() {
        let lhs = TestADTensor::from_data(Data::<f64, 2>::from([[1.0, 7.0], [2.0, 3.0]]));
        let rhs = TestADTensor::from_data(Data::<f64, 2>::from([[4.0, 7.0], [2.0, 3.0]]));

        let product = lhs.matmul(&rhs.transpose());
        let output = product.transpose();
        let grads = output.backward();

        let lhs_grad = lhs.grad(&grads).unwrap();
        let rhs_grad = rhs.grad(&grads).unwrap();

        assert_eq!(lhs_grad.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]]));
        assert_eq!(rhs_grad.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]]));
    }

    /// Gradients of two chained `matmul`s whose right operands are
    /// `swap_dims` views of the same 3-D tensor.
    #[test]
    fn should_diff_swap_dims() {
        let lhs = TestADTensor::from_data(Data::<f64, 3>::from([
            [[0.0, 1.0], [3.0, 4.0]],
            [[6.0, 7.0], [9.0, 10.0]],
        ]));
        let rhs = TestADTensor::from_data(Data::<f64, 3>::from([
            [[1.0, 4.0], [2.0, 5.0]],
            [[7.0, 10.0], [8.0, 11.0]],
        ]));

        let first = lhs.matmul(&rhs.swap_dims(0, 2));
        let output = first.matmul(&rhs.swap_dims(1, 2));
        let grads = output.backward();

        let lhs_grad = lhs.grad(&grads).unwrap();
        let rhs_grad = rhs.grad(&grads).unwrap();

        assert_eq!(
            lhs_grad.to_data(),
            Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]])
        );
        assert_eq!(
            rhs_grad.to_data(),
            Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]])
        );
    }
}
|
|
@ -22,7 +22,6 @@ pub trait Backend:
|
|||
type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
|
||||
type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
|
||||
type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
|
||||
+ TensorOpsTranspose<Self::Elem, D>
|
||||
+ TensorOpsDetach<Self::Elem, D>
|
||||
+ Zeros<Self::TensorPrimitive<D>>
|
||||
+ Ones<Self::TensorPrimitive<D>>
|
||||
|
|
|
@ -12,4 +12,3 @@ mod mask;
|
|||
mod pow;
|
||||
mod precision;
|
||||
mod reshape;
|
||||
mod transpose;
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
use crate::{
|
||||
tensor::{backend::ndarray::NdArrayTensor, ops::*},
|
||||
NdArrayElement,
|
||||
};
|
||||
|
||||
impl<P, const D: usize> TensorOpsTranspose<P, D> for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Default + Clone + std::fmt::Debug + NdArrayElement,
|
||||
{
|
||||
fn transpose(&self) -> Self {
|
||||
self.swap_dims(D - 2, D - 1)
|
||||
}
|
||||
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
|
||||
let mut shape = self.shape;
|
||||
let dim1_new = shape.dims[dim2];
|
||||
let dim2_new = shape.dims[dim1];
|
||||
|
||||
shape.dims[dim1] = dim1_new;
|
||||
shape.dims[dim2] = dim2_new;
|
||||
|
||||
let mut array = self.array.clone();
|
||||
array.swap_axes(dim1, dim2);
|
||||
|
||||
Self { array, shape }
|
||||
}
|
||||
}
|
|
@ -178,4 +178,21 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
|
||||
Self::mul_scalar(tensor, &(-1f32).to_elem::<E>())
|
||||
}
|
||||
fn swap_dims<const D: usize>(
|
||||
tensor: &NdArrayTensor<E, D>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
let mut shape = tensor.shape;
|
||||
let dim1_new = shape.dims[dim2];
|
||||
let dim2_new = shape.dims[dim1];
|
||||
|
||||
shape.dims[dim1] = dim1_new;
|
||||
shape.dims[dim2] = dim2_new;
|
||||
|
||||
let mut array = tensor.array.clone();
|
||||
array.swap_axes(dim1, dim2);
|
||||
|
||||
NdArrayTensor { array, shape }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,4 +12,3 @@ mod mask;
|
|||
mod pow;
|
||||
mod precision;
|
||||
mod reshape;
|
||||
mod transpose;
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
use crate::tensor::{backend::tch::TchTensor, ops::*, Shape};
|
||||
|
||||
impl<P: tch::kind::Element, const D: usize> TensorOpsTranspose<P, D> for TchTensor<P, D> {
|
||||
fn transpose(&self) -> Self {
|
||||
let tensor = self.tensor.transpose(-2, -1);
|
||||
let kind = self.kind.clone();
|
||||
let shape = Shape::from(tensor.size());
|
||||
|
||||
Self {
|
||||
kind,
|
||||
tensor,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
|
||||
let tensor = self.tensor.transpose(dim1 as i64, dim2 as i64);
|
||||
let kind = self.kind.clone();
|
||||
let shape = Shape::from(tensor.size());
|
||||
|
||||
Self {
|
||||
kind,
|
||||
tensor,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -127,6 +127,15 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
fn neg<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
Self::mul_scalar(tensor, &(-1f32).to_elem::<E>())
|
||||
}
|
||||
|
||||
fn swap_dims<const D: usize>(
|
||||
tensor: &TchTensor<E, D>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> TchTensor<E, D> {
|
||||
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
|
||||
to_tensor(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
|
||||
|
|
|
@ -197,7 +197,7 @@ where
|
|||
///
|
||||
/// If the tensor is of 1 dimension or less.
|
||||
pub fn transpose(&self) -> Self {
|
||||
Self::new(self.value.transpose())
|
||||
Self::new(B::transpose(&self.value))
|
||||
}
|
||||
|
||||
/// Swap two dimensions.
|
||||
|
@ -206,7 +206,7 @@ where
|
|||
///
|
||||
/// If the dimensions exceed the shape of the tensor.
|
||||
pub fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
|
||||
Self::new(self.value.swap_dims(dim1, dim2))
|
||||
Self::new(B::swap_dims(&self.value, dim1, dim2))
|
||||
}
|
||||
|
||||
/// Applies the matrix multiplication operation.
|
||||
|
|
|
@ -103,11 +103,14 @@ pub trait TensorOps<B: Backend> {
|
|||
rhs: &B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
fn neg<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
}
|
||||
|
||||
pub trait TensorOpsTranspose<E, const D: usize> {
|
||||
fn transpose(&self) -> Self;
|
||||
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self;
|
||||
fn transpose<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
Self::swap_dims(tensor, D - 2, D - 1)
|
||||
}
|
||||
fn swap_dims<const D: usize>(
|
||||
tensor: &B::TensorPrimitive<D>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
}
|
||||
|
||||
pub trait TensorOpsReshape<B: Backend, const D: usize> {
|
||||
|
|
|
@ -7,3 +7,4 @@ mod mul;
|
|||
mod neg;
|
||||
mod softmax;
|
||||
mod sub;
|
||||
mod transpose;
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
use crate::tensor::TestADTensor;
|
||||
use burn_tensor::Data;
|
||||
|
||||
/// Gradients of `matmul` combined with `transpose` on a 2-D input.
#[test]
fn should_diff_transpose() {
    let lhs = TestADTensor::from_data(Data::<f32, 2>::from([[1.0, 7.0], [2.0, 3.0]]));
    let rhs = TestADTensor::from_data(Data::<f32, 2>::from([[4.0, 7.0], [2.0, 3.0]]));

    let product = lhs.matmul(&rhs.transpose());
    let output = product.transpose();
    let grads = output.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    assert_eq!(lhs_grad.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]]));
    assert_eq!(rhs_grad.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]]));
}
|
||||
|
||||
/// Gradients of two chained `matmul`s whose right operands are `swap_dims`
/// views of the same 3-D tensor.
#[test]
fn should_diff_swap_dims() {
    let lhs = TestADTensor::from_data(Data::<f32, 3>::from([
        [[0.0, 1.0], [3.0, 4.0]],
        [[6.0, 7.0], [9.0, 10.0]],
    ]));
    let rhs = TestADTensor::from_data(Data::<f32, 3>::from([
        [[1.0, 4.0], [2.0, 5.0]],
        [[7.0, 10.0], [8.0, 11.0]],
    ]));

    let first = lhs.matmul(&rhs.swap_dims(0, 2));
    let output = first.matmul(&rhs.swap_dims(1, 2));
    let grads = output.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    assert_eq!(
        lhs_grad.to_data(),
        Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]])
    );
    assert_eq!(
        rhs_grad.to_data(),
        Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]])
    );
}
|
Loading…
Reference in New Issue