mirror of https://github.com/tracel-ai/burn.git
feat(trait-TensorOps): add log1p (#160)
This commit is contained in:
parent
993a6e3095
commit
0b85cb0eed
|
@ -1031,6 +1031,33 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
unary_ops_wrapper(tensor.node.clone(), output, ops)
|
||||
}
|
||||
|
||||
fn log1p<const D: usize>(
|
||||
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
|
||||
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
|
||||
#[derive(Default, Debug)]
|
||||
struct Backward<B: Backend, const D: usize> {
|
||||
_b: B,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
|
||||
for Backward<B, D>
|
||||
{
|
||||
fn partial(
|
||||
&self,
|
||||
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
|
||||
) -> B::TensorPrimitive<D> {
|
||||
let value = state.input.value();
|
||||
let value = B::div(&value.ones(), &B::add_scalar(&value, &1.to_elem()));
|
||||
B::mul(&state.output.grad(), &value)
|
||||
}
|
||||
}
|
||||
|
||||
let output = B::log(tensor.tensor_ref());
|
||||
let ops = Backward::<B, D>::default();
|
||||
|
||||
unary_ops_wrapper(tensor.node.clone(), output, ops)
|
||||
}
|
||||
|
||||
fn powf<const D: usize>(
|
||||
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
|
||||
value: f32,
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
#[burn_tensor_testgen::testgen(ad_log1p)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_diff_log1p() {
|
||||
let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let tensor_1 = TestADTensor::from_data(data_1);
|
||||
let tensor_2 = TestADTensor::from_data(data_2);
|
||||
|
||||
let tensor_3 = tensor_1.matmul(&tensor_2.log1p());
|
||||
let tensor_4 = tensor_3.matmul(&tensor_2);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[60.2652, 72.3130], [60.2652, 72.3130]]), 3);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[21.9328, 23.4864], [23.8506, 25.9870]]), 3);
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ mod erf;
|
|||
mod exp;
|
||||
mod index;
|
||||
mod log;
|
||||
mod log1p;
|
||||
mod mask;
|
||||
mod matmul;
|
||||
mod maxpool2d;
|
||||
|
@ -54,6 +55,7 @@ macro_rules! testgen_all {
|
|||
burn_autodiff::testgen_ad_exp!();
|
||||
burn_autodiff::testgen_ad_index!();
|
||||
burn_autodiff::testgen_ad_log!();
|
||||
burn_autodiff::testgen_ad_log1p!();
|
||||
burn_autodiff::testgen_ad_mask!();
|
||||
burn_autodiff::testgen_ad_matmul!();
|
||||
burn_autodiff::testgen_ad_mul!();
|
||||
|
|
|
@ -41,7 +41,7 @@ where
|
|||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, dataloader)| {
|
||||
let dataloader_cloned = dataloader.clone();
|
||||
let dataloader_cloned = dataloader;
|
||||
let sender_cloned = sender.clone();
|
||||
|
||||
thread::spawn(move || {
|
||||
|
|
|
@ -75,7 +75,6 @@ impl<B: Backend> TransformerEncoder<B> {
|
|||
/// Create the module from the given configuration.
|
||||
pub fn new(config: &TransformerEncoderConfig) -> Self {
|
||||
let layers = (0..config.n_layers)
|
||||
.into_iter()
|
||||
.map(|_| TransformerEncoderLayer::new(config))
|
||||
.collect();
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ pub(crate) trait NdArrayElement:
|
|||
pub(crate) trait ExpElement {
|
||||
fn exp_elem(self) -> Self;
|
||||
fn log_elem(self) -> Self;
|
||||
fn log1p_elem(self) -> Self;
|
||||
fn pow_elem(self, value: f32) -> Self;
|
||||
fn sqrt_elem(self) -> Self;
|
||||
}
|
||||
|
@ -21,6 +22,9 @@ macro_rules! impl_exp_elem {
|
|||
fn log_elem(self) -> Self {
|
||||
$elem::ln(self)
|
||||
}
|
||||
fn log1p_elem(self) -> Self {
|
||||
$elem::ln_1p(self)
|
||||
}
|
||||
fn pow_elem(self, value: f32) -> Self {
|
||||
$elem::powf(self, value.into())
|
||||
}
|
||||
|
@ -39,6 +43,10 @@ macro_rules! impl_exp_elem {
|
|||
let tmp = $tmp::ln(self as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
fn log1p_elem(self) -> Self {
|
||||
let tmp = $tmp::ln_1p(self as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
fn pow_elem(self, value: f32) -> Self {
|
||||
let tmp = $tmp::powf(self as $tmp, value as $tmp);
|
||||
tmp as $elem
|
||||
|
|
|
@ -480,6 +480,12 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn log1p<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
let array = tensor.array.mapv(|a| a.log1p_elem()).into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn powf<const D: usize>(tensor: &NdArrayTensor<E, D>, value: f32) -> NdArrayTensor<E, D> {
|
||||
let array = tensor.array.mapv(|a| a.pow_elem(value)).into_shared();
|
||||
|
||||
|
|
|
@ -431,6 +431,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
to_tensor(tensor.tensor.log())
|
||||
}
|
||||
|
||||
fn log1p<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
to_tensor(tensor.tensor.log1p())
|
||||
}
|
||||
|
||||
fn powf<const D: usize>(tensor: &TchTensor<E, D>, value: f32) -> TchTensor<E, D> {
|
||||
to_tensor(tensor.tensor.pow_tensor_scalar(value as f64))
|
||||
}
|
||||
|
|
|
@ -86,6 +86,13 @@ where
|
|||
Self::new(B::log(&self.value))
|
||||
}
|
||||
|
||||
/// Applies the natural logarithm of one plus the input tensor, element-wise.
|
||||
///
|
||||
/// `y = log(x+1)`
|
||||
pub fn log1p(&self) -> Self {
|
||||
Self::new(B::log1p(&self.value))
|
||||
}
|
||||
|
||||
/// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
|
||||
///
|
||||
/// `y = erf(x)`
|
||||
|
|
|
@ -213,6 +213,7 @@ pub trait TensorOps<B: Backend> {
|
|||
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
fn exp<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
fn log1p<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
|
||||
fn sqrt<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
fn cos<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
|
|
@ -25,6 +25,8 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_div!();
|
||||
burn_tensor::testgen_erf!();
|
||||
burn_tensor::testgen_exp!();
|
||||
burn_tensor::testgen_log!();
|
||||
burn_tensor::testgen_log1p!();
|
||||
burn_tensor::testgen_index!();
|
||||
burn_tensor::testgen_map_comparison!();
|
||||
burn_tensor::testgen_mask!();
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
#[burn_tensor_testgen::testgen(log)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_exp_ops() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = tensor.log().into_data();
|
||||
|
||||
let data_expected = Data::from([
|
||||
[-f32::INFINITY, 0.0, std::f32::consts::LN_2],
|
||||
[1.0986, 1.3862, 1.6094],
|
||||
]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
#[burn_tensor_testgen::testgen(log1p)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_exp_ops() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = tensor.log1p().into_data();
|
||||
|
||||
let data_expected = Data::from([
|
||||
[0.0, std::f32::consts::LN_2, 1.0986],
|
||||
[1.3862, 1.6094, 1.7917],
|
||||
]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
}
|
|
@ -6,6 +6,8 @@ mod div;
|
|||
mod erf;
|
||||
mod exp;
|
||||
mod index;
|
||||
mod log;
|
||||
mod log1p;
|
||||
mod map_comparison;
|
||||
mod mask;
|
||||
mod matmul;
|
||||
|
|
Loading…
Reference in New Issue