refactor: cat ops (#100)

Nathaniel Simard 2022-11-12 13:02:10 -05:00 committed by GitHub
parent ab39b8779b
commit da7a8e3f6a
15 changed files with 180 additions and 218 deletions
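
Summary: the standalone `TensorOpsCat` trait, previously implemented by each tensor primitive, is removed; concatenation becomes a `cat` function on the `TensorOps` trait, called as `B::cat(&[primitive], dim)`. The autodiff graph ops for cat now live inside the `ADBackendDecorator` implementation, the ndarray and tch backends gain their own `cat` implementations, `Tensor::cat` forwards to `B::cat`, and a gradient test for cat is added. A minimal, self-contained sketch of the resulting shape of the API (toy names such as `VecBackend`, `Primitive`, and `tensor_cat` are hypothetical, not the burn API):

// Hypothetical, simplified sketch: `cat` is a method on the backend ops trait
// instead of a trait implemented by each tensor primitive.
trait Backend {
    type Primitive: Clone;
    // Concatenation now lives on the backend trait itself.
    fn cat(tensors: &[Self::Primitive], dim: usize) -> Self::Primitive;
}

struct VecBackend;

impl Backend for VecBackend {
    type Primitive = Vec<f32>;
    fn cat(tensors: &[Vec<f32>], _dim: usize) -> Vec<f32> {
        // 1-D toy "concatenation": chain the inputs end to end.
        tensors.iter().flat_map(|t| t.iter().copied()).collect()
    }
}

// Mirrors the new `Tensor::cat`, which simply forwards to `B::cat`.
fn tensor_cat<B: Backend>(tensors: Vec<B::Primitive>, dim: usize) -> B::Primitive {
    B::cat(&tensors, dim)
}

fn main() {
    let out = tensor_cat::<VecBackend>(vec![vec![1.0, 2.0], vec![3.0]], 0);
    assert_eq!(out, vec![1.0, 2.0, 3.0]);
}

Routing cat through the backend trait keeps the primitive types free of op-specific trait bounds and lets the autodiff decorator wrap the call like any other `TensorOps` method.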

View File

@@ -1,159 +0,0 @@
use crate::graph::converter::Forward2BackwardGraphConverter;
use crate::graph::node::{BackwardNode, BackwardNodeRef, BackwardNodeState, ForwardNodeRef};
use crate::graph::ops::{
    BackwardRecordedOps, BackwardRecordedOpsRef, ForwardRecordedOps, RecordedOpsParentRef,
};
use crate::tensor::backend::Backend;
use crate::tensor::{backend::autodiff::ADTensor, ops::*};
use std::convert::TryInto;
use std::sync::Arc;
#[derive(new, Debug)]
pub struct ForwardCatOps<const D: usize, B: Backend> {
    nodes: Vec<ForwardNodeRef<B::TensorPrimitive<D>>>,
    dim: usize,
}
#[derive(new, Debug)]
pub struct BackwardCatOps<const D: usize, B: Backend> {
    nodes: Vec<BackwardNodeRef<B::TensorPrimitive<D>>>,
    dim: usize,
}
impl<const D: usize, B: Backend> ForwardRecordedOps<B::TensorPrimitive<D>> for ForwardCatOps<D, B> {
    fn to_backward(
        &self,
        graph: &mut Forward2BackwardGraphConverter,
    ) -> BackwardRecordedOpsRef<B::TensorPrimitive<D>> {
        Arc::new(BackwardCatOps::<D, B>::new(
            self.nodes
                .iter()
                .map(|node| {
                    let ops: BackwardNode<B::TensorPrimitive<D>> =
                        BackwardNode::from_node(node, graph);
                    Arc::new(ops)
                })
                .collect(),
            self.dim,
        ))
    }
}
impl<const D: usize, B: Backend> BackwardRecordedOps<B::TensorPrimitive<D>>
    for BackwardCatOps<D, B>
{
    fn backward_step(&self, state: &BackwardNodeState<B::TensorPrimitive<D>>) {
        let grad = state.grad();
        let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
        let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
        self.nodes.iter().enumerate().for_each(|(i, node)| {
            let mut indexes = indexes.clone();
            indexes[self.dim] = i..i + 1;
            node.state.update_grad(B::index(&grad, indexes));
        });
    }
    fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
        self.nodes
            .iter()
            .map(|node| {
                let ops: RecordedOpsParentRef = node.clone();
                ops
            })
            .collect()
    }
}
impl<B: Backend, const D: usize> TensorOpsCat<B::Elem, D> for ADTensor<D, B> {
    fn cat(tensors: Vec<&Self>, dim: usize) -> Self {
        let nodes: Vec<_> = tensors.iter().map(|t| t.node.clone()).collect();
        let order = nodes.iter().map(|node| node.order).max().unwrap() + 1;
        let tensors_inner: Vec<B::TensorPrimitive<D>> =
            tensors.into_iter().map(|a| a.tensor()).collect();
        let tensors_inner_ref: Vec<&B::TensorPrimitive<D>> = tensors_inner.iter().collect();
        let out = TensorOpsCat::cat(tensors_inner_ref, dim);
        let shape = *B::shape(&out);
        let state = crate::graph::node::ForwardNodeState::new(out);
        let ops = ForwardCatOps::<D, B>::new(nodes, dim);
        let ops = Arc::new(ops);
        let node = crate::graph::node::ForwardNode::new(order, state, ops);
        let node = std::sync::Arc::new(node);
        ADTensor { node, shape }
    }
}
#[cfg(test)]
mod tests {
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
    #[test]
    fn should_diff_cat() {
        let data_1 = Data::<_, 2>::from([[2.0, -1.0], [5.0, 2.0]]);
        let data_2 = Data::<_, 2>::from([[5.0, 4.0], [-1.0, 4.0]]);
        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);
        let tensor_3 = tensor_1.matmul(&tensor_2);
        let grads = tensor_3.backward();
        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();
        let mut tensor_1_list = Vec::new();
        let mut tensor_2_list = Vec::new();
        for i in 0..2 {
            tensor_1_list.push(TestADTensor::from_data(
                tensor_1.index([i..i + 1]).to_data(),
            ));
            tensor_2_list.push(TestADTensor::from_data(
                tensor_2.index([i..i + 1]).to_data(),
            ));
        }
        let tensor_1_cat = TestADTensor::cat(tensor_1_list.clone(), 0);
        let tensor_2_cat = TestADTensor::cat(tensor_2_list.clone(), 0);
        let tensor_3_cat = tensor_1_cat.matmul(&tensor_2_cat);
        let grads_cat = tensor_3_cat.backward();
        let grad_1_cat = tensor_1_cat.grad(&grads_cat).unwrap();
        let grad_2_cat = tensor_2_cat.grad(&grads_cat).unwrap();
        let grad_1_list_1 = tensor_1_list.get(0).unwrap().grad(&grads_cat).unwrap();
        let grad_1_list_2 = tensor_1_list.get(1).unwrap().grad(&grads_cat).unwrap();
        let grad_2_list_1 = tensor_2_list.get(0).unwrap().grad(&grads_cat).unwrap();
        let grad_2_list_2 = tensor_2_list.get(1).unwrap().grad(&grads_cat).unwrap();
        grad_1.to_data().assert_approx_eq(&grad_1_cat.to_data(), 3);
        grad_2.to_data().assert_approx_eq(&grad_2_cat.to_data(), 3);
        grad_1
            .index([0..1])
            .to_data()
            .assert_approx_eq(&grad_1_list_1.to_data(), 3);
        grad_1
            .index([1..2])
            .to_data()
            .assert_approx_eq(&grad_1_list_2.to_data(), 3);
        grad_2
            .index([0..1])
            .to_data()
            .assert_approx_eq(&grad_2_list_1.to_data(), 3);
        grad_2
            .index([1..2])
            .to_data()
            .assert_approx_eq(&grad_2_list_2.to_data(), 3);
    }
}

View File

@@ -1,5 +1,4 @@
mod base;
mod cat;
mod creation;
mod module;
mod tensor;

View File

@@ -1,5 +1,10 @@
use super::{binary_ops_wrapper, unary_ops_wrapper};
use crate::backend::autodiff::ops::unary_ops_wrapper_explicit;
use crate::graph::converter::Forward2BackwardGraphConverter;
use crate::graph::node::{BackwardNode, BackwardNodeRef, BackwardNodeState, ForwardNodeRef};
use crate::graph::ops::{
    BackwardRecordedOps, BackwardRecordedOpsRef, ForwardRecordedOps, RecordedOpsParentRef,
};
use crate::tensor::ElementConversion;
use crate::{
    backend::{
@@ -11,6 +16,7 @@ use crate::{
    Data, Shape, Tensor,
};
use std::ops::Range;
use std::sync::Arc;
impl<B: Backend, const D: usize> std::ops::Add<ADTensor<D, B>> for ADTensor<D, B> {
    type Output = ADTensor<D, B>;
@@ -1045,4 +1051,85 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }
    fn cat<const D: usize>(
        tensors: &[<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>],
        dim: usize,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        #[derive(new, Debug)]
        pub struct ForwardCatOps<const D: usize, B: Backend> {
            nodes: Vec<ForwardNodeRef<B::TensorPrimitive<D>>>,
            dim: usize,
        }
        #[derive(new, Debug)]
        pub struct BackwardCatOps<const D: usize, B: Backend> {
            nodes: Vec<BackwardNodeRef<B::TensorPrimitive<D>>>,
            dim: usize,
        }
        impl<const D: usize, B: Backend> ForwardRecordedOps<B::TensorPrimitive<D>> for ForwardCatOps<D, B> {
            fn to_backward(
                &self,
                graph: &mut Forward2BackwardGraphConverter,
            ) -> BackwardRecordedOpsRef<B::TensorPrimitive<D>> {
                Arc::new(BackwardCatOps::<D, B>::new(
                    self.nodes
                        .iter()
                        .map(|node| {
                            let ops: BackwardNode<B::TensorPrimitive<D>> =
                                BackwardNode::from_node(node, graph);
                            Arc::new(ops)
                        })
                        .collect(),
                    self.dim,
                ))
            }
        }
        impl<const D: usize, B: Backend> BackwardRecordedOps<B::TensorPrimitive<D>>
            for BackwardCatOps<D, B>
        {
            fn backward_step(&self, state: &BackwardNodeState<B::TensorPrimitive<D>>) {
                let grad = state.grad();
                let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
                let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
                self.nodes.iter().enumerate().for_each(|(i, node)| {
                    let mut indexes = indexes.clone();
                    indexes[self.dim] = i..i + 1;
                    node.state.update_grad(B::index(&grad, indexes));
                });
            }
            fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
                self.nodes
                    .iter()
                    .map(|node| {
                        let ops: RecordedOpsParentRef = node.clone();
                        ops
                    })
                    .collect()
            }
        }
        let nodes: Vec<_> = tensors.iter().map(|t| t.node.clone()).collect();
        let order = nodes.iter().map(|node| node.order).max().unwrap() + 1;
        let tensors_inner: Vec<B::TensorPrimitive<D>> =
            tensors.iter().map(|a| a.tensor()).collect();
        let out = B::cat(&tensors_inner, dim);
        let shape = *B::shape(&out);
        let state = crate::graph::node::ForwardNodeState::new(out);
        let ops = ForwardCatOps::<D, B>::new(nodes, dim);
        let ops = Arc::new(ops);
        let node = crate::graph::node::ForwardNode::new(order, state, ops);
        let node = std::sync::Arc::new(node);
        ADTensor { node, shape }
    }
}
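
For intuition on the backward pass above: the gradient of a concatenation is obtained by slicing the output gradient back to each input's position along the concatenation dimension, which is what `backward_step` does through `B::index`. A tiny, self-contained sketch of a slightly more general version of that rule on flat 1-D data (the helper `cat_backward` is hypothetical, not part of burn):

// Toy illustration of cat's backward rule: split the output gradient into
// per-input gradients according to each input's extent along the
// concatenated dimension (everything is 1-D here for simplicity).
fn cat_backward(grad_output: &[f32], input_lens: &[usize]) -> Vec<Vec<f32>> {
    let mut grads = Vec::with_capacity(input_lens.len());
    let mut offset = 0;
    for &len in input_lens {
        grads.push(grad_output[offset..offset + len].to_vec());
        offset += len;
    }
    grads
}

fn main() {
    // Inputs of length 2 and 3 were concatenated into a length-5 output.
    let grads = cat_backward(&[0.1, 0.2, 0.3, 0.4, 0.5], &[2, 3]);
    assert_eq!(grads, vec![vec![0.1, 0.2], vec![0.3, 0.4, 0.5]]);
}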

View File

@@ -23,7 +23,6 @@ pub trait Backend:
    type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
        + Zeros<Self::TensorPrimitive<D>>
        + Ones<Self::TensorPrimitive<D>>
        + TensorOpsCat<Self::Elem, D>
        + ReLU<Self::Elem, D>
        + Clone
        + Send

View File

@@ -23,7 +23,7 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
                [index..index + 1, 0..d_model],
            ));
        }
        let embedding = TensorOpsCat::cat(tensors.iter().collect(), 0);
        let embedding = NdArrayBackend::cat(&tensors, 0);
        NdArrayBackend::reshape(&embedding, Shape::new([batch_size, seq_length, d_model]))
    }

View File

@@ -1,20 +0,0 @@
use crate::{
    tensor::{backend::ndarray::NdArrayTensor, ops::*},
    NdArrayElement,
};
use ndarray::{Axis, IxDyn};
impl<P: NdArrayElement, const D: usize> TensorOpsCat<P, D> for NdArrayTensor<P, D> {
    fn cat(tensors: Vec<&Self>, dim: usize) -> Self {
        let mut shape = tensors.get(0).unwrap().shape;
        shape.dims[dim] = tensors.len();
        let arrays: Vec<ndarray::ArrayView<P, IxDyn>> =
            tensors.into_iter().map(|t| t.array.view()).collect();
        let array = ndarray::concatenate(Axis(dim), &arrays)
            .unwrap()
            .into_shared();
        Self { array, shape }
    }
}

View File

@@ -1,2 +1 @@
mod cat;
mod creation;

View File

@@ -4,7 +4,7 @@ use crate::{
    ops::TensorOps,
    to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
};
use ndarray::{Axis, Dim, SliceInfoElem};
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
use std::{cmp::Ordering, ops::Range};
macro_rules! keepdim {
@@ -484,6 +484,19 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
        NdArrayTensor { array, shape }
    }
    fn cat<const D: usize>(tensors: &[NdArrayTensor<E, D>], dim: usize) -> NdArrayTensor<E, D> {
        let mut shape = tensors.get(0).unwrap().shape;
        shape.dims[dim] = tensors.len();
        let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
            tensors.iter().map(|t| t.array.view()).collect();
        let array = ndarray::concatenate(Axis(dim), &arrays)
            .unwrap()
            .into_shared();
        NdArrayTensor { array, shape }
    }
}
fn to_slice_args<const D1: usize, const D2: usize>(

View File

@@ -1,23 +0,0 @@
use crate::{
    backend::tch::TchKind,
    tensor::{backend::tch::TchTensor, ops::*, Shape},
    TchElement,
};
impl<P: TchElement, const D: usize> TensorOpsCat<P, D> for TchTensor<P, D> {
    fn cat(tensors: Vec<&Self>, dim: usize) -> Self {
        let tensors: Vec<tch::Tensor> = tensors
            .into_iter()
            .map(|t| t.tensor.shallow_clone())
            .collect();
        let tensor = tch::Tensor::cat(&tensors, dim as i64);
        let shape = Shape::from(tensor.size());
        let kind = TchKind::new();
        Self {
            tensor,
            shape,
            kind,
        }
    }
}

View File

@@ -1,2 +1 @@
mod cat;
mod creation;

View File

@@ -382,6 +382,12 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
    fn erf<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.erf())
    }
    fn cat<const D: usize>(tensors: &[TchTensor<E, D>], dim: usize) -> TchTensor<E, D> {
        let tensors: Vec<tch::Tensor> = tensors.iter().map(|t| t.tensor.shallow_clone()).collect();
        let tensor = tch::Tensor::cat(&tensors, dim as i64);
        to_tensor(tensor)
    }
}
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {

View File

@@ -2,7 +2,6 @@ use crate::graph::grad::Gradients;
use crate::tensor::backend::ADBackend;
use crate::tensor::backend::Backend;
use crate::tensor::ops::activation::*;
use crate::tensor::ops::*;
use crate::tensor::stats;
use crate::tensor::ElementConversion;
use crate::tensor::{Data, Distribution, Shape};
@@ -506,11 +505,10 @@ where
    ///
    /// If all tensors don't have the same shape.
    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
        let tensors: Vec<B::TensorPrimitive<D>> = tensors.into_iter().map(|a| a.value).collect();
        let tensors: Vec<&B::TensorPrimitive<D>> = tensors.iter().collect();
        let value = B::TensorPrimitive::cat(tensors, dim);
        Self::new(value)
        Self::new(B::cat(
            &tensors.into_iter().map(|t| t.value).collect::<Vec<_>>(),
            dim,
        ))
    }
    /// Detach the current tensor from the autodiff graph.

View File

@@ -196,10 +196,7 @@ pub trait TensorOps<B: Backend> {
    fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
    fn erf<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
}
pub trait TensorOpsCat<E, const D: usize> {
    fn cat(tensors: Vec<&Self>, dim: usize) -> Self;
    fn cat<const D: usize>(tensors: &[B::TensorPrimitive<D>], dim: usize) -> B::TensorPrimitive<D>;
}
pub trait Zeros<T> {

View File

@@ -0,0 +1,66 @@
use crate::tensor::TestADTensor;
use burn_tensor::Data;
#[test]
fn should_diff_cat() {
    let data_1 = Data::<_, 2>::from([[2.0, -1.0], [5.0, 2.0]]);
    let data_2 = Data::<_, 2>::from([[5.0, 4.0], [-1.0, 4.0]]);
    let tensor_1 = TestADTensor::from_data(data_1);
    let tensor_2 = TestADTensor::from_data(data_2);
    let tensor_3 = tensor_1.matmul(&tensor_2);
    let grads = tensor_3.backward();
    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();
    let mut tensor_1_list = Vec::new();
    let mut tensor_2_list = Vec::new();
    for i in 0..2 {
        tensor_1_list.push(TestADTensor::from_data(
            tensor_1.index([i..i + 1]).to_data(),
        ));
        tensor_2_list.push(TestADTensor::from_data(
            tensor_2.index([i..i + 1]).to_data(),
        ));
    }
    let tensor_1_cat = TestADTensor::cat(tensor_1_list.clone(), 0);
    let tensor_2_cat = TestADTensor::cat(tensor_2_list.clone(), 0);
    let tensor_3_cat = tensor_1_cat.matmul(&tensor_2_cat);
    let grads_cat = tensor_3_cat.backward();
    let grad_1_cat = tensor_1_cat.grad(&grads_cat).unwrap();
    let grad_2_cat = tensor_2_cat.grad(&grads_cat).unwrap();
    let grad_1_list_1 = tensor_1_list.get(0).unwrap().grad(&grads_cat).unwrap();
    let grad_1_list_2 = tensor_1_list.get(1).unwrap().grad(&grads_cat).unwrap();
    let grad_2_list_1 = tensor_2_list.get(0).unwrap().grad(&grads_cat).unwrap();
    let grad_2_list_2 = tensor_2_list.get(1).unwrap().grad(&grads_cat).unwrap();
    grad_1.to_data().assert_approx_eq(&grad_1_cat.to_data(), 3);
    grad_2.to_data().assert_approx_eq(&grad_2_cat.to_data(), 3);
    grad_1
        .index([0..1])
        .to_data()
        .assert_approx_eq(&grad_1_list_1.to_data(), 3);
    grad_1
        .index([1..2])
        .to_data()
        .assert_approx_eq(&grad_1_list_2.to_data(), 3);
    grad_2
        .index([0..1])
        .to_data()
        .assert_approx_eq(&grad_2_list_1.to_data(), 3);
    grad_2
        .index([1..2])
        .to_data()
        .assert_approx_eq(&grad_2_list_2.to_data(), 3);
}

View File

@@ -1,5 +1,6 @@
mod add;
mod aggregation;
mod cat;
mod cross_entropy;
mod div;
mod erf;