mirror of https://github.com/tracel-ai/burn.git
Implement tensor.recip() function to calculate elementwise reciprocals (#953)
This commit is contained in:
parent e882d41f8b
commit 4fc0c27e31
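A minimal usage sketch of the new op (not part of the diff below; it assumes the NdArray backend is exposed as `burn::backend::NdArray` and reuses the values from the autodiff test added in this commit):

use burn::backend::NdArray;
use burn::tensor::{Data, Tensor};

fn main() {
    // Elementwise reciprocal: every element x becomes 1 / x.
    let x = Tensor::<NdArray, 1>::from_floats([2.0, 5.0, 0.4]);
    let y = x.recip();
    y.into_data().assert_approx_eq(&Data::from([0.5, 0.2, 2.5]), 3);
}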
@@ -435,6 +435,32 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
            .stateless(B::neg(tensor.primitive))
    }

    fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        #[derive(Debug)]
        struct Recip;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                let tensor = ops.state;
                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    let tmp = B::powf(tensor, -2.0);
                    let value = B::neg(tmp);

                    B::mul(grad, value)
                });
            }
        }

        match Recip.prepare([tensor.node], [tensor.graph]).stateful() {
            OpsKind::Tracked(prep) => {
                prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive))
            }
            OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)),
        }
    }

    fn swap_dims<const D: usize>(
        tensor: FloatTensor<Self, D>,
        dim1: usize,
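For reference (not part of the diff): the backward pass above applies the closed-form derivative d(1/x)/dx = -1/x², i.e. the incoming gradient is multiplied by the negated inverse square of the saved forward input. A standalone sketch checking the gradient values asserted in the autodiff test added later in this commit:

// Plain Rust check, independent of burn: the gradient of 1/x is -(x^-2),
// so the test inputs [2.0, 5.0, 0.4] give gradients [-0.25, -0.04, -6.25].
fn recip_grad(x: f64) -> f64 {
    -x.powi(-2)
}

fn main() {
    for x in [2.0_f64, 5.0, 0.4] {
        println!("d(1/x)/dx at {x} = {}", recip_grad(x));
    }
}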
@ -34,6 +34,7 @@ mod mul;
|
|||
mod multithread;
|
||||
mod neg;
|
||||
mod pow;
|
||||
mod recip;
|
||||
mod relu;
|
||||
mod reshape;
|
||||
mod select;
|
||||
|
@@ -94,6 +95,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_mul!();
        burn_autodiff::testgen_ad_neg!();
        burn_autodiff::testgen_ad_powf!();
        burn_autodiff::testgen_ad_recip!();
        burn_autodiff::testgen_ad_reshape!();
        burn_autodiff::testgen_ad_sin!();
        burn_autodiff::testgen_ad_softmax!();
@@ -0,0 +1,20 @@
#[burn_tensor_testgen::testgen(ad_recip)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_recip() {
        let data = Data::from([2.0, 5.0, 0.4]);

        let tensor = TestAutodiffTensor::from_data(data).require_grad();
        let tensor_out = tensor.clone().recip();

        let grads = tensor_out.backward();
        let grad = tensor.grad(&grads).unwrap();

        assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5]));
        grad.to_data()
            .assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3);
    }
}
@@ -126,6 +126,7 @@ Those operations are only available for `Float` tensors.
| `tensor.erf()`       | `tensor.erf()`        |
| `tensor.powf(value)` | `tensor.pow(value)`   |
| `tensor.sqrt()`      | `tensor.sqrt()`       |
| `tensor.recip()`     | `tensor.reciprocal()` |
| `tensor.cos()`       | `tensor.cos()`        |
| `tensor.sin()`       | `tensor.sin()`        |
| `tensor.tanh()`      | `tensor.tanh()`       |
@@ -56,6 +56,7 @@ mod tests {
    burn_tensor::testgen_arg!();
    burn_tensor::testgen_cast!();
    burn_tensor::testgen_cat!();
    burn_tensor::testgen_recip!();
    burn_tensor::testgen_clamp!();
    burn_tensor::testgen_cos!();
    // burn_tensor::testgen_div!();
@@ -133,6 +134,7 @@ mod tests {
    burn_autodiff::testgen_ad_mul!();
    burn_autodiff::testgen_ad_neg!();
    burn_autodiff::testgen_ad_powf!();
    burn_autodiff::testgen_ad_recip!();
    burn_autodiff::testgen_ad_reshape!();
    burn_autodiff::testgen_ad_sin!();
    burn_autodiff::testgen_ad_softmax!();
@@ -442,4 +442,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
    ) -> FloatTensor<Self, D> {
        CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
    }

    fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        CandleTensor::new(tensor.tensor.recip().unwrap())
    }
}
@@ -100,6 +100,11 @@ pub enum FloatOpsDescription<B: FusionBackend> {
        (TensorDescription, Distribution<FloatElem<B>>),
        Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
    ),
    /// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
    Recip(
        UnaryOpsDescription,
        Box<dyn Ops<B, Args = UnaryOpsDescription>>,
    ),
}

/// Operation description specific to module.
@@ -1252,6 +1257,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
            FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
            FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
            FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
            FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input),
            FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
            FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
            FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
@@ -1268,6 +1274,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
            FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
            FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
            FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
            FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles),
            FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
            FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
            FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),
@@ -1391,6 +1391,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
        out
    }

    fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        unary_float_ops!(Recip, B::recip);

        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
        out.client
            .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
                UnaryOpsDescription {
                    input: tensor.into_description(),
                    out: out.to_description_out(),
                },
                Box::new(Recip::<D>),
            )));
        out
    }

    fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        unary_float_ops!(TanhOps, B::erf);
@@ -134,7 +134,7 @@ represent the corresponding Burn Op.
  | [RandomUniform][128]     | ❌ | ✅ |
  | [RandomUniformLike][129] | ❌ | ✅ |
  | [Range][130]             | ❌ | ✅ |
- | [Reciprocal][131]        | ❌ | ❌ |
+ | [Reciprocal][131]        | ✅ | ✅ |
  | [ReduceL][132]           | ❌ | ❌ |
  | [ReduceLogSum][133]      | ❌ | ❌ |
  | [ReduceLogSumExp][134]   | ❌ | ❌ |
@@ -27,6 +27,7 @@ fn main() {
        .input("tests/log_softmax/log_softmax.onnx")
        .input("tests/maxpool2d/maxpool2d.onnx")
        .input("tests/mul/mul.onnx")
        .input("tests/recip/recip.onnx")
        .input("tests/relu/relu.onnx")
        .input("tests/reshape/reshape.onnx")
        .input("tests/sigmoid/sigmoid.onnx")
@@ -34,6 +34,7 @@ include_models!(
    log_softmax,
    maxpool2d,
    mul,
    recip,
    relu,
    reshape,
    sigmoid,
@@ -591,4 +592,17 @@ mod tests {
        let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn recip() {
        // Initialize the model
        let model = recip::Model::<Backend>::new();

        // Run the model
        let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
        let output = model.forward(input);
        // data from pyTorch
        let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
        output.to_data().assert_approx_eq(&expected, 4);
    }
}
@@ -0,0 +1,17 @@
(binary file: the exported ONNX model `recip.onnx`, produced by PyTorch 2.1.0; it contains a single Reciprocal node in `main_graph`. Raw protobuf bytes omitted.)
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/recip/recip.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return x.reciprocal()


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "recip.onnx"
    dummy_input = torch.randn(1, 2, 3, 4, device=device)

    torch.onnx.export(model, (dummy_input), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])

    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
@@ -26,6 +26,7 @@ pub enum UnaryNodeKind {
    LogSoftmax,
    Softmax,
    Relu,
    Reciprocal,
    Sigmoid,
    Tanh,
    Transpose,
@@ -40,6 +41,7 @@ impl UnaryNodeKind {
            Self::LogSoftmax => "log_softmax",
            Self::Softmax => "softmax",
            Self::Relu => "relu",
            Self::Reciprocal => "reciprocal",
            Self::Sigmoid => "sigmoid",
            Self::Tanh => "tanh",
            Self::Transpose => "transpose",
@@ -141,6 +143,11 @@ impl UnaryNode {
        Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
    }

    pub(crate) fn reciprocal(input: Type, output: Type) -> Self {
        let function = move |input| quote! { #input.recip() };
        Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
    }

    /// Casts the input to the output type.
    ///
    /// Currently this function only supports the following conversions:
@@ -334,6 +341,25 @@ mod tests {
        );
    }

    #[test]
    fn test_unary_codegen_reciprocal() {
        one_node_graph(
            UnaryNode::reciprocal(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.recip();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_cast() {
        one_node_graph(
@@ -79,6 +79,7 @@ pub fn dim_inference(
        NodeType::Erf => same_as_input(node),
        NodeType::Sqrt => same_as_input(node),
        NodeType::Tanh => same_as_input(node),
        NodeType::Reciprocal => same_as_input(node),
        NodeType::Softmax => same_as_input(node),
        NodeType::ReduceMean => mean_update_outputs(node),
        NodeType::Constant => constant_update_outputs(node),
@@ -250,6 +250,7 @@ impl ONNXGraph {
                NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
                NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
                NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
                NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
                NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
                NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
                NodeType::Concat => graph.register(Self::concat_conversion(node)),
@@ -447,6 +448,13 @@ impl ONNXGraph {
        UnaryNode::sigmoid(input, output)
    }

    fn reciprocal_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();

        UnaryNode::reciprocal(input, output)
    }

    fn log_softmax_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();
@@ -181,6 +181,13 @@ where
        NdArrayTensor { array }
    }

    pub fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let array = tensor.array.map(|x| 1.elem::<E>() / *x);
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
        let data = Data::from([tensor.array.mean().unwrap()]);
        NdArrayTensor::from_data(data)
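The `recip` above maps x to 1 / x over every element; `1.elem::<E>()` only converts the literal 1 into the backend's element type. A minimal sketch of the same elementwise mapping on a concrete f32 array, using the ndarray crate directly (illustration only, not the backend code):

use ndarray::array;

fn main() {
    let a = array![2.0_f32, 5.0, 0.4];
    // map visits each element by reference and builds a new owned array.
    let r = a.map(|x| 1.0 / *x);
    println!("{r}"); // [0.5, 0.2, 2.5]
}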
@@ -127,6 +127,10 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
        Self::mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        NdArrayMathOps::recip(tensor)
    }

    fn swap_dims<const D: usize>(
        tensor: NdArrayTensor<E, D>,
        dim1: usize,
@@ -172,6 +172,10 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
        Self::mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        TchTensor::new(tensor.tensor.reciprocal())
    }

    fn swap_dims<const D: usize>(
        tensor: TchTensor<E, D>,
        dim1: usize,
@@ -67,6 +67,11 @@ where
        Self::new(B::powf(self.primitive, value))
    }

    /// Applies element wise reciprocal operation.
    pub fn recip(self) -> Self {
        Self::new(B::recip(self.primitive))
    }

    /// Applies element wise root square operation.
    pub fn sqrt(self) -> Self {
        Self::new(B::sqrt(self.primitive))
@@ -410,6 +410,9 @@ pub trait TensorOps<B: Backend> {
        Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
    }

    /// Calculates the reciprocals elementwise
    fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;

    /// Transposes a tensor.
    ///
    /// # Arguments
@@ -60,6 +60,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_one_hot!();
        burn_tensor::testgen_powf!();
        burn_tensor::testgen_random!();
        burn_tensor::testgen_recip!();
        burn_tensor::testgen_repeat!();
        burn_tensor::testgen_reshape!();
        burn_tensor::testgen_select!();
@@ -28,6 +28,7 @@ mod neg;
mod one_hot;
mod powf;
mod random;
mod recip;
mod repeat;
mod reshape;
mod select;
@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(recip)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_support_recip_ops() {
        let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]);
        let tensor = Tensor::<TestBackend, 2>::from_data(data);

        let data_actual = tensor.recip().into_data();

        let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);
        data_expected.assert_approx_eq(&data_actual, 3);
    }
}
@@ -510,4 +510,17 @@ where
    ) -> FloatTensor<Self, D> {
        kernel::clamp(tensor, min, max)
    }

    fn recip<const D: usize>(
        tensor: FloatTensor<Wgpu<G, F, I>, D>,
    ) -> FloatTensor<Wgpu<G, F, I>, D> {
        unary!(Recip, func "1.0 /");
        unary_inplace!(RecipInplace, func "1.0 /");

        if tensor.can_mut() {
            return unary_inplace_default::<RecipInplace, F, D>(tensor);
        }

        unary_default::<Recip, F, D>(tensor)
    }
}