mirror of https://github.com/tracel-ai/burn.git
Feat/candle/module ops (#725)
This commit is contained in:
parent
aafceeffa0
commit
760c9e1d8e
|
@ -28,15 +28,15 @@ mod tests {
|
||||||
type TestADTensor<const D: usize, K> = burn_tensor::Tensor<TestADBackend, D, K>;
|
type TestADTensor<const D: usize, K> = burn_tensor::Tensor<TestADBackend, D, K>;
|
||||||
|
|
||||||
// test activation
|
// test activation
|
||||||
// burn_tensor::testgen_gelu!();
|
burn_tensor::testgen_gelu!();
|
||||||
// burn_tensor::testgen_relu!();
|
burn_tensor::testgen_relu!();
|
||||||
// burn_tensor::testgen_softmax!();
|
burn_tensor::testgen_softmax!();
|
||||||
// burn_tensor::testgen_sigmoid!();
|
burn_tensor::testgen_sigmoid!();
|
||||||
// burn_tensor::testgen_silu!();
|
burn_tensor::testgen_silu!();
|
||||||
|
|
||||||
// test module
|
// test module
|
||||||
// burn_tensor::testgen_module_forward!();
|
burn_tensor::testgen_module_forward!();
|
||||||
// burn_tensor::testgen_module_conv1d!();
|
burn_tensor::testgen_module_conv1d!();
|
||||||
// burn_tensor::testgen_module_conv2d!();
|
// burn_tensor::testgen_module_conv2d!();
|
||||||
// burn_tensor::testgen_module_conv_transpose1d!();
|
// burn_tensor::testgen_module_conv_transpose1d!();
|
||||||
// burn_tensor::testgen_module_conv_transpose2d!();
|
// burn_tensor::testgen_module_conv_transpose2d!();
|
||||||
|
@ -72,7 +72,7 @@ mod tests {
|
||||||
burn_tensor::testgen_maxmin!();
|
burn_tensor::testgen_maxmin!();
|
||||||
burn_tensor::testgen_mul!();
|
burn_tensor::testgen_mul!();
|
||||||
burn_tensor::testgen_neg!();
|
burn_tensor::testgen_neg!();
|
||||||
// burn_tensor::testgen_powf!();
|
burn_tensor::testgen_powf!();
|
||||||
burn_tensor::testgen_random!();
|
burn_tensor::testgen_random!();
|
||||||
// burn_tensor::testgen_repeat!();
|
// burn_tensor::testgen_repeat!();
|
||||||
burn_tensor::testgen_reshape!();
|
burn_tensor::testgen_reshape!();
|
||||||
|
@ -87,14 +87,15 @@ mod tests {
|
||||||
burn_tensor::testgen_transpose!();
|
burn_tensor::testgen_transpose!();
|
||||||
|
|
||||||
// test stats
|
// test stats
|
||||||
// burn_tensor::testgen_stats!();
|
burn_tensor::testgen_var!();
|
||||||
|
burn_tensor::testgen_display!();
|
||||||
|
|
||||||
// Behavior
|
// Behavior
|
||||||
// burn_autodiff::testgen_ad_broadcast!();
|
// burn_autodiff::testgen_ad_broadcast!();
|
||||||
|
|
||||||
// Activation
|
// Activation
|
||||||
// burn_autodiff::testgen_ad_relu!();
|
burn_autodiff::testgen_ad_relu!();
|
||||||
// burn_autodiff::testgen_ad_gelu!();
|
burn_autodiff::testgen_ad_gelu!();
|
||||||
|
|
||||||
// Modules
|
// Modules
|
||||||
// burn_autodiff::testgen_ad_conv1d!();
|
// burn_autodiff::testgen_ad_conv1d!();
|
||||||
|
@ -107,36 +108,36 @@ mod tests {
|
||||||
// burn_autodiff::testgen_ad_avg_pool2d!();
|
// burn_autodiff::testgen_ad_avg_pool2d!();
|
||||||
// burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
|
// burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
|
||||||
// burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
|
// burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
|
||||||
// burn_autodiff::testgen_module_backward!();
|
burn_autodiff::testgen_module_backward!();
|
||||||
|
|
||||||
// Tensor
|
// Tensor
|
||||||
// burn_autodiff::testgen_ad_complex!();
|
burn_autodiff::testgen_ad_complex!();
|
||||||
// burn_autodiff::testgen_ad_multithread!();
|
burn_autodiff::testgen_ad_multithread!();
|
||||||
// burn_autodiff::testgen_ad_add!();
|
burn_autodiff::testgen_ad_add!();
|
||||||
// burn_autodiff::testgen_ad_aggregation!();
|
burn_autodiff::testgen_ad_aggregation!();
|
||||||
// burn_autodiff::testgen_ad_maxmin!();
|
burn_autodiff::testgen_ad_maxmin!();
|
||||||
// burn_autodiff::testgen_ad_cat!();
|
// burn_autodiff::testgen_ad_cat!();
|
||||||
// burn_autodiff::testgen_ad_cos!();
|
burn_autodiff::testgen_ad_cos!();
|
||||||
// burn_autodiff::testgen_ad_cross_entropy_loss!();
|
burn_autodiff::testgen_ad_cross_entropy_loss!();
|
||||||
// burn_autodiff::testgen_ad_div!();
|
burn_autodiff::testgen_ad_div!();
|
||||||
// burn_autodiff::testgen_ad_erf!();
|
// burn_autodiff::testgen_ad_erf!();
|
||||||
// burn_autodiff::testgen_ad_exp!();
|
burn_autodiff::testgen_ad_exp!();
|
||||||
// burn_autodiff::testgen_ad_slice!();
|
// burn_autodiff::testgen_ad_slice!();
|
||||||
// burn_autodiff::testgen_ad_gather_scatter!();
|
burn_autodiff::testgen_ad_gather_scatter!();
|
||||||
// burn_autodiff::testgen_ad_select!();
|
burn_autodiff::testgen_ad_select!();
|
||||||
// burn_autodiff::testgen_ad_log!();
|
burn_autodiff::testgen_ad_log!();
|
||||||
// burn_autodiff::testgen_ad_log1p!();
|
burn_autodiff::testgen_ad_log1p!();
|
||||||
// burn_autodiff::testgen_ad_mask!();
|
burn_autodiff::testgen_ad_mask!();
|
||||||
// burn_autodiff::testgen_ad_matmul!();
|
burn_autodiff::testgen_ad_matmul!();
|
||||||
// burn_autodiff::testgen_ad_mul!();
|
burn_autodiff::testgen_ad_mul!();
|
||||||
// burn_autodiff::testgen_ad_neg!();
|
burn_autodiff::testgen_ad_neg!();
|
||||||
// burn_autodiff::testgen_ad_powf!();
|
burn_autodiff::testgen_ad_powf!();
|
||||||
// burn_autodiff::testgen_ad_reshape!();
|
burn_autodiff::testgen_ad_reshape!();
|
||||||
// burn_autodiff::testgen_ad_sin!();
|
burn_autodiff::testgen_ad_sin!();
|
||||||
// burn_autodiff::testgen_ad_softmax!();
|
burn_autodiff::testgen_ad_softmax!();
|
||||||
// burn_autodiff::testgen_ad_sqrt!();
|
burn_autodiff::testgen_ad_sqrt!();
|
||||||
// burn_autodiff::testgen_ad_abs!();
|
burn_autodiff::testgen_ad_abs!();
|
||||||
// burn_autodiff::testgen_ad_sub!();
|
burn_autodiff::testgen_ad_sub!();
|
||||||
// burn_autodiff::testgen_ad_tanh!();
|
burn_autodiff::testgen_ad_tanh!();
|
||||||
// burn_autodiff::testgen_ad_transpose!();
|
burn_autodiff::testgen_ad_transpose!();
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,19 @@ use burn_tensor::ops::ActivationOps;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||||
CandleBackend,
|
tensor, CandleBackend, CandleTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::base::FloatTensor;
|
||||||
|
|
||||||
impl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<CandleBackend<F, I>>
|
impl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<CandleBackend<F, I>>
|
||||||
for CandleBackend<F, I>
|
for CandleBackend<F, I>
|
||||||
{
|
{
|
||||||
|
fn gelu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
CandleTensor::new(tensor.tensor.gelu().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
CandleTensor::new(tensor.tensor.relu().unwrap())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,6 +98,5 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
|
||||||
ranges: [std::ops::Range<usize>; D2],
|
ranges: [std::ops::Range<usize>; D2],
|
||||||
value: CandleTensor<E, D1>,
|
value: CandleTensor<E, D1>,
|
||||||
) -> CandleTensor<E, D1> {
|
) -> CandleTensor<E, D1> {
|
||||||
// TODO: not trivial, because no view_ like in torch
|
panic!("slice_assign not supported by Candle")
|
||||||
todo!()
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
use burn_tensor::ops::{
|
use burn_tensor::{
|
||||||
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
ops::{ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps},
|
||||||
|
Shape,
|
||||||
};
|
};
|
||||||
|
use candle_core::ToUsize2;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||||
CandleBackend,
|
ops::base::reshape,
|
||||||
|
CandleBackend, CandleTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::base::{FloatTensor, IntTensor};
|
use super::base::{FloatTensor, IntTensor};
|
||||||
|
@ -12,13 +15,76 @@ use super::base::{FloatTensor, IntTensor};
|
||||||
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
for CandleBackend<F, I>
|
for CandleBackend<F, I>
|
||||||
{
|
{
|
||||||
|
fn conv1d(
|
||||||
|
x: FloatTensor<Self, 3>,
|
||||||
|
weight: FloatTensor<Self, 3>,
|
||||||
|
bias: Option<FloatTensor<Self, 1>>,
|
||||||
|
options: ConvOptions<1>,
|
||||||
|
) -> FloatTensor<Self, 3> {
|
||||||
|
let conv = x
|
||||||
|
.tensor
|
||||||
|
.conv1d(
|
||||||
|
&weight.tensor,
|
||||||
|
options.padding[0],
|
||||||
|
options.stride[0],
|
||||||
|
options.dilation[0],
|
||||||
|
options.groups,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
CandleTensor::new(match bias {
|
||||||
|
Some(bias) => conv
|
||||||
|
.broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
|
||||||
|
.unwrap(),
|
||||||
|
None => conv,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
x: FloatTensor<Self, 4>,
|
x: FloatTensor<Self, 4>,
|
||||||
weight: FloatTensor<Self, 4>,
|
weight: FloatTensor<Self, 4>,
|
||||||
bias: Option<FloatTensor<Self, 1>>,
|
bias: Option<FloatTensor<Self, 1>>,
|
||||||
options: ConvOptions<2>,
|
options: ConvOptions<2>,
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
assert!(
|
||||||
|
options.dilation[0] == options.dilation[1]
|
||||||
|
&& options.padding[0] == options.padding[1]
|
||||||
|
&& options.stride[0] == options.stride[1],
|
||||||
|
"Candle does not support per dimension options in convolutions"
|
||||||
|
);
|
||||||
|
let conv = x
|
||||||
|
.tensor
|
||||||
|
.conv2d(
|
||||||
|
&weight.tensor,
|
||||||
|
options.padding[0],
|
||||||
|
options.stride[0],
|
||||||
|
options.dilation[0],
|
||||||
|
options.groups,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
CandleTensor::new(match bias {
|
||||||
|
Some(bias) => conv
|
||||||
|
.broadcast_add(
|
||||||
|
&bias
|
||||||
|
.tensor
|
||||||
|
.unsqueeze(0)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(2)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(3)
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
None => conv,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv_transpose1d(
|
||||||
|
x: FloatTensor<Self, 3>,
|
||||||
|
weight: FloatTensor<Self, 3>,
|
||||||
|
bias: Option<FloatTensor<Self, 1>>,
|
||||||
|
options: ConvTransposeOptions<1>,
|
||||||
|
) -> FloatTensor<Self, 3> {
|
||||||
|
panic!("Candle does not support conv_transpose1d")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv_transpose2d(
|
fn conv_transpose2d(
|
||||||
|
@ -27,7 +93,42 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
bias: Option<FloatTensor<Self, 1>>,
|
bias: Option<FloatTensor<Self, 1>>,
|
||||||
options: ConvTransposeOptions<2>,
|
options: ConvTransposeOptions<2>,
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
assert!(
|
||||||
|
options.dilation[0] == options.dilation[1]
|
||||||
|
&& options.padding[0] == options.padding[1]
|
||||||
|
&& options.padding_out[0] == options.padding_out[1]
|
||||||
|
&& options.stride[0] == options.stride[1],
|
||||||
|
"Candle does not support per dimension options in transposed convolutions"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
options.groups == 1,
|
||||||
|
"Candle does not support groups in transposed convolutions"
|
||||||
|
);
|
||||||
|
let conv_transpose = x
|
||||||
|
.tensor
|
||||||
|
.conv_transpose2d(
|
||||||
|
&weight.tensor,
|
||||||
|
options.padding[0],
|
||||||
|
options.padding_out[0],
|
||||||
|
options.stride[0],
|
||||||
|
options.dilation[0],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
CandleTensor::new(match bias {
|
||||||
|
Some(bias) => conv_transpose
|
||||||
|
.broadcast_add(
|
||||||
|
&bias
|
||||||
|
.tensor
|
||||||
|
.unsqueeze(0)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(2)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(3)
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
None => conv_transpose,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn avg_pool2d(
|
fn avg_pool2d(
|
||||||
|
@ -37,7 +138,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
padding: [usize; 2],
|
padding: [usize; 2],
|
||||||
count_include_pad: bool,
|
count_include_pad: bool,
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
assert!(
|
||||||
|
padding[0] == 0 && padding[1] == 0,
|
||||||
|
"Candle does not support padding in pooling"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
count_include_pad,
|
||||||
|
"Candle does not support excluding pad count in pooling"
|
||||||
|
);
|
||||||
|
CandleTensor::new(
|
||||||
|
x.tensor
|
||||||
|
.avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn avg_pool2d_backward(
|
fn avg_pool2d_backward(
|
||||||
|
@ -48,7 +161,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
padding: [usize; 2],
|
padding: [usize; 2],
|
||||||
count_include_pad: bool,
|
count_include_pad: bool,
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
panic!("avg_pool2d_backward is not supported by Candle")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_pool2d(
|
fn max_pool2d(
|
||||||
|
@ -58,7 +171,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
padding: [usize; 2],
|
padding: [usize; 2],
|
||||||
dilation: [usize; 2],
|
dilation: [usize; 2],
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
assert!(
|
||||||
|
padding[0] == 0 && padding[1] == 0,
|
||||||
|
"Candle does not support padding in pooling"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
dilation[0] == 1 && dilation[1] == 1,
|
||||||
|
"Candle does not support dilation in pooling"
|
||||||
|
);
|
||||||
|
CandleTensor::new(
|
||||||
|
x.tensor
|
||||||
|
.max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_pool2d_with_indices(
|
fn max_pool2d_with_indices(
|
||||||
|
@ -68,7 +193,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
padding: [usize; 2],
|
padding: [usize; 2],
|
||||||
dilation: [usize; 2],
|
dilation: [usize; 2],
|
||||||
) -> MaxPool2dWithIndices<CandleBackend<F, I>> {
|
) -> MaxPool2dWithIndices<CandleBackend<F, I>> {
|
||||||
todo!()
|
panic!("max_pool2d_with_indices is not supported by Candle")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_pool2d_with_indices_backward(
|
fn max_pool2d_with_indices_backward(
|
||||||
|
@ -80,20 +205,20 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
||||||
output_grad: FloatTensor<Self, 4>,
|
output_grad: FloatTensor<Self, 4>,
|
||||||
indices: IntTensor<Self, 4>,
|
indices: IntTensor<Self, 4>,
|
||||||
) -> MaxPool2dBackward<CandleBackend<F, I>> {
|
) -> MaxPool2dBackward<CandleBackend<F, I>> {
|
||||||
todo!()
|
panic!("max_pool2d_with_indices_backward is not supported by Candle")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn adaptive_avg_pool2d(
|
fn adaptive_avg_pool2d(
|
||||||
x: FloatTensor<Self, 4>,
|
x: FloatTensor<Self, 4>,
|
||||||
output_size: [usize; 2],
|
output_size: [usize; 2],
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
panic!("adaptive_avg_pool2 is not supported by Candle")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn adaptive_avg_pool2d_backward(
|
fn adaptive_avg_pool2d_backward(
|
||||||
x: FloatTensor<Self, 4>,
|
x: FloatTensor<Self, 4>,
|
||||||
grad: FloatTensor<Self, 4>,
|
grad: FloatTensor<Self, 4>,
|
||||||
) -> FloatTensor<Self, 4> {
|
) -> FloatTensor<Self, 4> {
|
||||||
todo!()
|
panic!("adaptive_avg_pool2d_backward is not supported by Candle")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -369,7 +369,12 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
|
||||||
}
|
}
|
||||||
|
|
||||||
fn powf<const D: usize>(tensor: FloatTensor<Self, D>, value: f32) -> FloatTensor<Self, D> {
|
fn powf<const D: usize>(tensor: FloatTensor<Self, D>, value: f32) -> FloatTensor<Self, D> {
|
||||||
panic!("powf not supported by Candle")
|
CandleTensor::new(
|
||||||
|
(tensor.tensor.log().unwrap() * value.elem::<f64>())
|
||||||
|
.unwrap()
|
||||||
|
.exp()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
|
|
@ -5,12 +5,14 @@ impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
|
||||||
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||||
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
|
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||||
tensor.unary_ops(
|
tensor.unary_ops(
|
||||||
|mut tensor| tensor.gelu_("none"),
|
|mut tensor| tensor.gelu_("none"),
|
||||||
|tensor| tensor.gelu("none"),
|
|tensor| tensor.gelu("none"),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gelu_backward<const D: usize>(
|
fn gelu_backward<const D: usize>(
|
||||||
tensor: TchTensor<E, D>,
|
tensor: TchTensor<E, D>,
|
||||||
grad: TchTensor<E, D>,
|
grad: TchTensor<E, D>,
|
||||||
|
|
|
@ -68,6 +68,7 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_transpose!();
|
burn_tensor::testgen_transpose!();
|
||||||
|
|
||||||
// test stats
|
// test stats
|
||||||
burn_tensor::testgen_stats!();
|
burn_tensor::testgen_var!();
|
||||||
|
burn_tensor::testgen_display!();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_max_pool1d_simple() {
|
fn test_max_pool1d_simple() {
|
||||||
let kernel_size = 3;
|
let kernel_size = 3;
|
||||||
let padding = 1;
|
let padding = 0;
|
||||||
let stride = 1;
|
let stride = 1;
|
||||||
let dilation = 1;
|
let dilation = 1;
|
||||||
|
|
||||||
|
@ -18,8 +18,8 @@ mod tests {
|
||||||
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||||
]]);
|
]]);
|
||||||
let y = TestTensor::from_floats([[
|
let y = TestTensor::from_floats([[
|
||||||
[0.9861, 0.9861, 0.5474, 0.4477, 0.8221, 0.8221],
|
[0.9861, 0.5474, 0.4477, 0.8221],
|
||||||
[0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
|
[0.949, 0.949, 0.949, 0.789],
|
||||||
]]);
|
]]);
|
||||||
|
|
||||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
||||||
|
@ -81,13 +81,13 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_max_pool1d_with_indices() {
|
fn test_max_pool1d_with_indices() {
|
||||||
let kernel_size = 2;
|
let kernel_size = 2;
|
||||||
let padding = 1;
|
let padding = 0;
|
||||||
let stride = 1;
|
let stride = 1;
|
||||||
let dilation = 1;
|
let dilation = 1;
|
||||||
|
|
||||||
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
|
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
|
||||||
let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
|
let indices = Data::<IntElem, 3>::from([[[1, 1, 3]]]);
|
||||||
let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
|
let y = TestTensor::from_floats([[[0.6386, 0.6386, 0.5742]]]);
|
||||||
|
|
||||||
let (output, output_indices) =
|
let (output, output_indices) =
|
||||||
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
|
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#[burn_tensor_testgen::testgen(stats)]
|
#[burn_tensor_testgen::testgen(display)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use burn_tensor::backend::Backend;
|
use burn_tensor::backend::Backend;
|
||||||
|
@ -7,17 +7,6 @@ mod tests {
|
||||||
type FloatElem = <TestBackend as Backend>::FloatElem;
|
type FloatElem = <TestBackend as Backend>::FloatElem;
|
||||||
type IntElem = <TestBackend as Backend>::IntElem;
|
type IntElem = <TestBackend as Backend>::IntElem;
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_var() {
|
|
||||||
let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
|
||||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
|
||||||
|
|
||||||
let data_actual = tensor.var(1).into_data();
|
|
||||||
|
|
||||||
let data_expected = Data::from([[2.4892], [15.3333]]);
|
|
||||||
data_expected.assert_approx_eq(&data_actual, 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_display_2d_int_tensor() {
|
fn test_display_2d_int_tensor() {
|
||||||
let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
|
let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
|
|
@ -1 +1,2 @@
|
||||||
mod basic;
|
mod display;
|
||||||
|
mod var;
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
#[burn_tensor_testgen::testgen(var)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::backend::Backend;
|
||||||
|
use burn_tensor::{Data, Tensor};
|
||||||
|
|
||||||
|
type FloatElem = <TestBackend as Backend>::FloatElem;
|
||||||
|
type IntElem = <TestBackend as Backend>::IntElem;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_var() {
|
||||||
|
let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||||
|
|
||||||
|
let data_actual = tensor.var(1).into_data();
|
||||||
|
|
||||||
|
let data_expected = Data::from([[2.4892], [15.3333]]);
|
||||||
|
data_expected.assert_approx_eq(&data_actual, 3);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue