mirror of https://github.com/tracel-ai/burn.git
Feat/candle/module ops (#725)
This commit is contained in:
parent
aafceeffa0
commit
760c9e1d8e
|
@ -28,15 +28,15 @@ mod tests {
|
|||
type TestADTensor<const D: usize, K> = burn_tensor::Tensor<TestADBackend, D, K>;
|
||||
|
||||
// test activation
|
||||
// burn_tensor::testgen_gelu!();
|
||||
// burn_tensor::testgen_relu!();
|
||||
// burn_tensor::testgen_softmax!();
|
||||
// burn_tensor::testgen_sigmoid!();
|
||||
// burn_tensor::testgen_silu!();
|
||||
burn_tensor::testgen_gelu!();
|
||||
burn_tensor::testgen_relu!();
|
||||
burn_tensor::testgen_softmax!();
|
||||
burn_tensor::testgen_sigmoid!();
|
||||
burn_tensor::testgen_silu!();
|
||||
|
||||
// test module
|
||||
// burn_tensor::testgen_module_forward!();
|
||||
// burn_tensor::testgen_module_conv1d!();
|
||||
burn_tensor::testgen_module_forward!();
|
||||
burn_tensor::testgen_module_conv1d!();
|
||||
// burn_tensor::testgen_module_conv2d!();
|
||||
// burn_tensor::testgen_module_conv_transpose1d!();
|
||||
// burn_tensor::testgen_module_conv_transpose2d!();
|
||||
|
@ -72,7 +72,7 @@ mod tests {
|
|||
burn_tensor::testgen_maxmin!();
|
||||
burn_tensor::testgen_mul!();
|
||||
burn_tensor::testgen_neg!();
|
||||
// burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_random!();
|
||||
// burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
|
@ -87,14 +87,15 @@ mod tests {
|
|||
burn_tensor::testgen_transpose!();
|
||||
|
||||
// test stats
|
||||
// burn_tensor::testgen_stats!();
|
||||
burn_tensor::testgen_var!();
|
||||
burn_tensor::testgen_display!();
|
||||
|
||||
// Behavior
|
||||
// burn_autodiff::testgen_ad_broadcast!();
|
||||
|
||||
// Activation
|
||||
// burn_autodiff::testgen_ad_relu!();
|
||||
// burn_autodiff::testgen_ad_gelu!();
|
||||
burn_autodiff::testgen_ad_relu!();
|
||||
burn_autodiff::testgen_ad_gelu!();
|
||||
|
||||
// Modules
|
||||
// burn_autodiff::testgen_ad_conv1d!();
|
||||
|
@ -107,36 +108,36 @@ mod tests {
|
|||
// burn_autodiff::testgen_ad_avg_pool2d!();
|
||||
// burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
|
||||
// burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
|
||||
// burn_autodiff::testgen_module_backward!();
|
||||
burn_autodiff::testgen_module_backward!();
|
||||
|
||||
// Tensor
|
||||
// burn_autodiff::testgen_ad_complex!();
|
||||
// burn_autodiff::testgen_ad_multithread!();
|
||||
// burn_autodiff::testgen_ad_add!();
|
||||
// burn_autodiff::testgen_ad_aggregation!();
|
||||
// burn_autodiff::testgen_ad_maxmin!();
|
||||
burn_autodiff::testgen_ad_complex!();
|
||||
burn_autodiff::testgen_ad_multithread!();
|
||||
burn_autodiff::testgen_ad_add!();
|
||||
burn_autodiff::testgen_ad_aggregation!();
|
||||
burn_autodiff::testgen_ad_maxmin!();
|
||||
// burn_autodiff::testgen_ad_cat!();
|
||||
// burn_autodiff::testgen_ad_cos!();
|
||||
// burn_autodiff::testgen_ad_cross_entropy_loss!();
|
||||
// burn_autodiff::testgen_ad_div!();
|
||||
burn_autodiff::testgen_ad_cos!();
|
||||
burn_autodiff::testgen_ad_cross_entropy_loss!();
|
||||
burn_autodiff::testgen_ad_div!();
|
||||
// burn_autodiff::testgen_ad_erf!();
|
||||
// burn_autodiff::testgen_ad_exp!();
|
||||
burn_autodiff::testgen_ad_exp!();
|
||||
// burn_autodiff::testgen_ad_slice!();
|
||||
// burn_autodiff::testgen_ad_gather_scatter!();
|
||||
// burn_autodiff::testgen_ad_select!();
|
||||
// burn_autodiff::testgen_ad_log!();
|
||||
// burn_autodiff::testgen_ad_log1p!();
|
||||
// burn_autodiff::testgen_ad_mask!();
|
||||
// burn_autodiff::testgen_ad_matmul!();
|
||||
// burn_autodiff::testgen_ad_mul!();
|
||||
// burn_autodiff::testgen_ad_neg!();
|
||||
// burn_autodiff::testgen_ad_powf!();
|
||||
// burn_autodiff::testgen_ad_reshape!();
|
||||
// burn_autodiff::testgen_ad_sin!();
|
||||
// burn_autodiff::testgen_ad_softmax!();
|
||||
// burn_autodiff::testgen_ad_sqrt!();
|
||||
// burn_autodiff::testgen_ad_abs!();
|
||||
// burn_autodiff::testgen_ad_sub!();
|
||||
// burn_autodiff::testgen_ad_tanh!();
|
||||
// burn_autodiff::testgen_ad_transpose!();
|
||||
burn_autodiff::testgen_ad_gather_scatter!();
|
||||
burn_autodiff::testgen_ad_select!();
|
||||
burn_autodiff::testgen_ad_log!();
|
||||
burn_autodiff::testgen_ad_log1p!();
|
||||
burn_autodiff::testgen_ad_mask!();
|
||||
burn_autodiff::testgen_ad_matmul!();
|
||||
burn_autodiff::testgen_ad_mul!();
|
||||
burn_autodiff::testgen_ad_neg!();
|
||||
burn_autodiff::testgen_ad_powf!();
|
||||
burn_autodiff::testgen_ad_reshape!();
|
||||
burn_autodiff::testgen_ad_sin!();
|
||||
burn_autodiff::testgen_ad_softmax!();
|
||||
burn_autodiff::testgen_ad_sqrt!();
|
||||
burn_autodiff::testgen_ad_abs!();
|
||||
burn_autodiff::testgen_ad_sub!();
|
||||
burn_autodiff::testgen_ad_tanh!();
|
||||
burn_autodiff::testgen_ad_transpose!();
|
||||
}
|
||||
|
|
|
@ -2,10 +2,19 @@ use burn_tensor::ops::ActivationOps;
|
|||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
CandleBackend,
|
||||
tensor, CandleBackend, CandleTensor,
|
||||
};
|
||||
|
||||
use super::base::FloatTensor;
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<CandleBackend<F, I>>
|
||||
for CandleBackend<F, I>
|
||||
{
|
||||
fn gelu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
CandleTensor::new(tensor.tensor.gelu().unwrap())
|
||||
}
|
||||
|
||||
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
CandleTensor::new(tensor.tensor.relu().unwrap())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,6 +98,5 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
|
|||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: CandleTensor<E, D1>,
|
||||
) -> CandleTensor<E, D1> {
|
||||
// TODO: not trivial, because no view_ like in torch
|
||||
todo!()
|
||||
panic!("slice_assign not supported by Candle")
|
||||
}
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
use burn_tensor::{
|
||||
ops::{ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps},
|
||||
Shape,
|
||||
};
|
||||
use candle_core::ToUsize2;
|
||||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
CandleBackend,
|
||||
ops::base::reshape,
|
||||
CandleBackend, CandleTensor,
|
||||
};
|
||||
|
||||
use super::base::{FloatTensor, IntTensor};
|
||||
|
@ -12,13 +15,76 @@ use super::base::{FloatTensor, IntTensor};
|
|||
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||
for CandleBackend<F, I>
|
||||
{
|
||||
fn conv1d(
|
||||
x: FloatTensor<Self, 3>,
|
||||
weight: FloatTensor<Self, 3>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<Self, 3> {
|
||||
let conv = x
|
||||
.tensor
|
||||
.conv1d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
options.groups,
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv
|
||||
.broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
|
||||
.unwrap(),
|
||||
None => conv,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
x: FloatTensor<Self, 4>,
|
||||
weight: FloatTensor<Self, 4>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
assert!(
|
||||
options.dilation[0] == options.dilation[1]
|
||||
&& options.padding[0] == options.padding[1]
|
||||
&& options.stride[0] == options.stride[1],
|
||||
"Candle does not support per dimension options in convolutions"
|
||||
);
|
||||
let conv = x
|
||||
.tensor
|
||||
.conv2d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
options.groups,
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv
|
||||
.broadcast_add(
|
||||
&bias
|
||||
.tensor
|
||||
.unsqueeze(0)
|
||||
.unwrap()
|
||||
.unsqueeze(2)
|
||||
.unwrap()
|
||||
.unsqueeze(3)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap(),
|
||||
None => conv,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: FloatTensor<Self, 3>,
|
||||
weight: FloatTensor<Self, 3>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self, 3> {
|
||||
panic!("Candle does not support conv_transpose1d")
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
|
@ -27,7 +93,42 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
assert!(
|
||||
options.dilation[0] == options.dilation[1]
|
||||
&& options.padding[0] == options.padding[1]
|
||||
&& options.padding_out[0] == options.padding_out[1]
|
||||
&& options.stride[0] == options.stride[1],
|
||||
"Candle does not support per dimension options in transposed convolutions"
|
||||
);
|
||||
assert!(
|
||||
options.groups == 1,
|
||||
"Candle does not support groups in transposed convolutions"
|
||||
);
|
||||
let conv_transpose = x
|
||||
.tensor
|
||||
.conv_transpose2d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv_transpose
|
||||
.broadcast_add(
|
||||
&bias
|
||||
.tensor
|
||||
.unsqueeze(0)
|
||||
.unwrap()
|
||||
.unsqueeze(2)
|
||||
.unwrap()
|
||||
.unsqueeze(3)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap(),
|
||||
None => conv_transpose,
|
||||
})
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
|
@ -37,7 +138,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
assert!(
|
||||
padding[0] == 0 && padding[1] == 0,
|
||||
"Candle does not support padding in pooling"
|
||||
);
|
||||
assert!(
|
||||
count_include_pad,
|
||||
"Candle does not support excluding pad count in pooling"
|
||||
);
|
||||
CandleTensor::new(
|
||||
x.tensor
|
||||
.avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
|
@ -48,7 +161,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
panic!("avg_pool2d_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
|
@ -58,7 +171,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
assert!(
|
||||
padding[0] == 0 && padding[1] == 0,
|
||||
"Candle does not support padding in pooling"
|
||||
);
|
||||
assert!(
|
||||
dilation[0] == 1 && dilation[1] == 1,
|
||||
"Candle does not support dilation in pooling"
|
||||
);
|
||||
CandleTensor::new(
|
||||
x.tensor
|
||||
.max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
|
@ -68,7 +193,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<CandleBackend<F, I>> {
|
||||
todo!()
|
||||
panic!("max_pool2d_with_indices is not supported by Candle")
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
|
@ -80,20 +205,20 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
output_grad: FloatTensor<Self, 4>,
|
||||
indices: IntTensor<Self, 4>,
|
||||
) -> MaxPool2dBackward<CandleBackend<F, I>> {
|
||||
todo!()
|
||||
panic!("max_pool2d_with_indices_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(
|
||||
x: FloatTensor<Self, 4>,
|
||||
output_size: [usize; 2],
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
panic!("adaptive_avg_pool2 is not supported by Candle")
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self, 4>,
|
||||
grad: FloatTensor<Self, 4>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
panic!("adaptive_avg_pool2d_backward is not supported by Candle")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -369,7 +369,12 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
|
|||
}
|
||||
|
||||
fn powf<const D: usize>(tensor: FloatTensor<Self, D>, value: f32) -> FloatTensor<Self, D> {
|
||||
panic!("powf not supported by Candle")
|
||||
CandleTensor::new(
|
||||
(tensor.tensor.log().unwrap() * value.elem::<f64>())
|
||||
.unwrap()
|
||||
.exp()
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
|
|
|
@ -5,12 +5,14 @@ impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
|
|||
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
|
||||
}
|
||||
|
||||
fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.gelu_("none"),
|
||||
|tensor| tensor.gelu("none"),
|
||||
)
|
||||
}
|
||||
|
||||
fn gelu_backward<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
grad: TchTensor<E, D>,
|
||||
|
|
|
@ -68,6 +68,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_transpose!();
|
||||
|
||||
// test stats
|
||||
burn_tensor::testgen_stats!();
|
||||
burn_tensor::testgen_var!();
|
||||
burn_tensor::testgen_display!();
|
||||
};
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_max_pool1d_simple() {
|
||||
let kernel_size = 3;
|
||||
let padding = 1;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
|
@ -18,8 +18,8 @@ mod tests {
|
|||
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||
]]);
|
||||
let y = TestTensor::from_floats([[
|
||||
[0.9861, 0.9861, 0.5474, 0.4477, 0.8221, 0.8221],
|
||||
[0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
|
||||
[0.9861, 0.5474, 0.4477, 0.8221],
|
||||
[0.949, 0.949, 0.949, 0.789],
|
||||
]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
||||
|
@ -81,13 +81,13 @@ mod tests {
|
|||
#[test]
|
||||
fn test_max_pool1d_with_indices() {
|
||||
let kernel_size = 2;
|
||||
let padding = 1;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
|
||||
let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
|
||||
let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
|
||||
let indices = Data::<IntElem, 3>::from([[[1, 1, 3]]]);
|
||||
let y = TestTensor::from_floats([[[0.6386, 0.6386, 0.5742]]]);
|
||||
|
||||
let (output, output_indices) =
|
||||
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#[burn_tensor_testgen::testgen(stats)]
|
||||
#[burn_tensor_testgen::testgen(display)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
|
@ -7,17 +7,6 @@ mod tests {
|
|||
type FloatElem = <TestBackend as Backend>::FloatElem;
|
||||
type IntElem = <TestBackend as Backend>::IntElem;
|
||||
|
||||
#[test]
|
||||
fn test_var() {
|
||||
let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = tensor.var(1).into_data();
|
||||
|
||||
let data_expected = Data::from([[2.4892], [15.3333]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_2d_int_tensor() {
|
||||
let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
|
|
@ -1 +1,2 @@
|
|||
mod basic;
|
||||
mod display;
|
||||
mod var;
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
#[burn_tensor_testgen::testgen(var)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
type FloatElem = <TestBackend as Backend>::FloatElem;
|
||||
type IntElem = <TestBackend as Backend>::IntElem;
|
||||
|
||||
#[test]
|
||||
fn test_var() {
|
||||
let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = tensor.var(1).into_data();
|
||||
|
||||
let data_expected = Data::from([[2.4892], [15.3333]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue