Implement tensor.recip() function to calculate elementwise reciprocals (#953)

This commit is contained in:
Zsombor 2023-11-15 15:17:32 +01:00 committed by GitHub
parent e882d41f8b
commit 4fc0c27e31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 241 additions and 1 deletions

View File

@ -435,6 +435,32 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
.stateless(B::neg(tensor.primitive))
}
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct Recip;
impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
type State = B::TensorPrimitive<D>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let tensor = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let tmp = B::powf(tensor, -2.0);
let value = B::neg(tmp);
B::mul(grad, value)
});
}
}
match Recip.prepare([tensor.node], [tensor.graph]).stateful() {
OpsKind::Tracked(prep) => {
prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive))
}
OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)),
}
}
fn swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,

View File

@ -34,6 +34,7 @@ mod mul;
mod multithread;
mod neg;
mod pow;
mod recip;
mod relu;
mod reshape;
mod select;
@ -94,6 +95,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();

View File

@ -0,0 +1,20 @@
#[burn_tensor_testgen::testgen(ad_recip)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_recip() {
let data = Data::from([2.0, 5.0, 0.4]);
let tensor = TestAutodiffTensor::from_data(data).require_grad();
let tensor_out = tensor.clone().recip();
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5]));
grad.to_data()
.assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3);
}
}

View File

@ -126,6 +126,7 @@ Those operations are only available for `Float` tensors.
| `tensor.erf()` | `tensor.erf()` |
| `tensor.powf(value)` | `tensor.pow(value)` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.tanh()` | `tensor.tanh()` |

View File

@ -56,6 +56,7 @@ mod tests {
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
// burn_tensor::testgen_div!();
@ -133,6 +134,7 @@ mod tests {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();

View File

@ -442,4 +442,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
}
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}
}

View File

@ -100,6 +100,11 @@ pub enum FloatOpsDescription<B: FusionBackend> {
(TensorDescription, Distribution<FloatElem<B>>),
Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
),
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
Recip(
UnaryOpsDescription,
Box<dyn Ops<B, Args = UnaryOpsDescription>>,
),
}
/// Operation description specific to module.
@ -1252,6 +1257,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
@ -1268,6 +1274,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),

View File

@ -1391,6 +1391,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Recip, B::recip);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
UnaryOpsDescription {
input: tensor.into_description(),
out: out.to_description_out(),
},
Box::new(Recip::<D>),
)));
out
}
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::erf);

View File

@ -134,7 +134,7 @@ represent the corresponding Burn Op.
| [RandomUniform][128] | ❌ | ✅ |
| [RandomUniformLike][129] | ❌ | ✅ |
| [Range][130] | ❌ | ✅ |
| [Reciprocal][131] | ❌ | ❌ |
| [Reciprocal][131] | ✅ | ✅ |
| [ReduceL][132] | ❌ | ❌ |
| [ReduceLogSum][133] | ❌ | ❌ |
| [ReduceLogSumExp][134] | ❌ | ❌ |

View File

@ -27,6 +27,7 @@ fn main() {
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")

View File

@ -34,6 +34,7 @@ include_models!(
log_softmax,
maxpool2d,
mul,
recip,
relu,
reshape,
sigmoid,
@ -591,4 +592,17 @@ mod tests {
let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
fn recip() {
// Initialize the model
let model = recip::Model::<Backend>::new();
// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
let output = model.forward(input);
// data from pyTorch
let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
}

View File

@ -0,0 +1,17 @@
pytorch2.1.0:‰
0
onnx::Reciprocal_01 /Reciprocal"
Reciprocal
main_graphZ,
onnx::Reciprocal_0




b
1




B

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/recip/recip.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return x.reciprocal()
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "recip.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)
torch.onnx.export(model, (dummy_input), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()

View File

@ -26,6 +26,7 @@ pub enum UnaryNodeKind {
LogSoftmax,
Softmax,
Relu,
Reciprocal,
Sigmoid,
Tanh,
Transpose,
@ -40,6 +41,7 @@ impl UnaryNodeKind {
Self::LogSoftmax => "log_softmax",
Self::Softmax => "softmax",
Self::Relu => "relu",
Self::Reciprocal => "reciprocal",
Self::Sigmoid => "sigmoid",
Self::Tanh => "tanh",
Self::Transpose => "transpose",
@ -141,6 +143,11 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
}
pub(crate) fn reciprocal(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.recip() };
Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
}
/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
@ -334,6 +341,25 @@ mod tests {
);
}
#[test]
fn test_unary_codegen_reciprocal() {
one_node_graph(
UnaryNode::reciprocal(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.recip();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_cast() {
one_node_graph(

View File

@ -79,6 +79,7 @@ pub fn dim_inference(
NodeType::Erf => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::Softmax => same_as_input(node),
NodeType::ReduceMean => mean_update_outputs(node),
NodeType::Constant => constant_update_outputs(node),

View File

@ -250,6 +250,7 @@ impl ONNXGraph {
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
NodeType::Concat => graph.register(Self::concat_conversion(node)),
@ -447,6 +448,13 @@ impl ONNXGraph {
UnaryNode::sigmoid(input, output)
}
fn reciprocal_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::reciprocal(input, output)
}
fn log_softmax_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();

View File

@ -181,6 +181,13 @@ where
NdArrayTensor { array }
}
pub fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.map(|x| 1.elem::<E>() / *x);
let array = array.into_shared();
NdArrayTensor { array }
}
pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
let data = Data::from([tensor.array.mean().unwrap()]);
NdArrayTensor::from_data(data)

View File

@ -127,6 +127,10 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
}
fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::recip(tensor)
}
fn swap_dims<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim1: usize,

View File

@ -172,6 +172,10 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
}
fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
TchTensor::new(tensor.tensor.reciprocal())
}
fn swap_dims<const D: usize>(
tensor: TchTensor<E, D>,
dim1: usize,

View File

@ -67,6 +67,11 @@ where
Self::new(B::powf(self.primitive, value))
}
/// Applies element wise reciprocal operation.
pub fn recip(self) -> Self {
Self::new(B::recip(self.primitive))
}
/// Applies element wise root square operation.
pub fn sqrt(self) -> Self {
Self::new(B::sqrt(self.primitive))

View File

@ -410,6 +410,9 @@ pub trait TensorOps<B: Backend> {
Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
}
/// Calculates the reciprocals elementwise
fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Transposes a tensor.
///
/// # Arguments

View File

@ -60,6 +60,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_one_hot!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_random!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();

View File

@ -28,6 +28,7 @@ mod neg;
mod one_hot;
mod powf;
mod random;
mod recip;
mod repeat;
mod reshape;
mod select;

View File

@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(recip)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_recip_ops() {
let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.recip().into_data();
let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}

View File

@ -510,4 +510,17 @@ where
) -> FloatTensor<Self, D> {
kernel::clamp(tensor, min, max)
}
fn recip<const D: usize>(
tensor: FloatTensor<Wgpu<G, F, I>, D>,
) -> FloatTensor<Wgpu<G, F, I>, D> {
unary!(Recip, func "1.0 /");
unary_inplace!(RecipInplace, func "1.0 /");
if tensor.can_mut() {
return unary_inplace_default::<RecipInplace, F, D>(tensor);
}
unary_default::<Recip, F, D>(tensor)
}
}