mirror of https://github.com/tracel-ai/burn.git
Implement tensor.recip() function to calculate elementwise reciprocals (#953)
This commit is contained in:
parent
e882d41f8b
commit
4fc0c27e31
|
@ -435,6 +435,32 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
||||||
.stateless(B::neg(tensor.primitive))
|
.stateless(B::neg(tensor.primitive))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Recip;
|
||||||
|
|
||||||
|
impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
|
||||||
|
type State = B::TensorPrimitive<D>;
|
||||||
|
|
||||||
|
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||||
|
let tensor = ops.state;
|
||||||
|
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
|
||||||
|
let tmp = B::powf(tensor, -2.0);
|
||||||
|
let value = B::neg(tmp);
|
||||||
|
|
||||||
|
B::mul(grad, value)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match Recip.prepare([tensor.node], [tensor.graph]).stateful() {
|
||||||
|
OpsKind::Tracked(prep) => {
|
||||||
|
prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive))
|
||||||
|
}
|
||||||
|
OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn swap_dims<const D: usize>(
|
fn swap_dims<const D: usize>(
|
||||||
tensor: FloatTensor<Self, D>,
|
tensor: FloatTensor<Self, D>,
|
||||||
dim1: usize,
|
dim1: usize,
|
||||||
|
|
|
@ -34,6 +34,7 @@ mod mul;
|
||||||
mod multithread;
|
mod multithread;
|
||||||
mod neg;
|
mod neg;
|
||||||
mod pow;
|
mod pow;
|
||||||
|
mod recip;
|
||||||
mod relu;
|
mod relu;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
mod select;
|
mod select;
|
||||||
|
@ -94,6 +95,7 @@ macro_rules! testgen_all {
|
||||||
burn_autodiff::testgen_ad_mul!();
|
burn_autodiff::testgen_ad_mul!();
|
||||||
burn_autodiff::testgen_ad_neg!();
|
burn_autodiff::testgen_ad_neg!();
|
||||||
burn_autodiff::testgen_ad_powf!();
|
burn_autodiff::testgen_ad_powf!();
|
||||||
|
burn_autodiff::testgen_ad_recip!();
|
||||||
burn_autodiff::testgen_ad_reshape!();
|
burn_autodiff::testgen_ad_reshape!();
|
||||||
burn_autodiff::testgen_ad_sin!();
|
burn_autodiff::testgen_ad_sin!();
|
||||||
burn_autodiff::testgen_ad_softmax!();
|
burn_autodiff::testgen_ad_softmax!();
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ad_recip)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::Data;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_recip() {
|
||||||
|
let data = Data::from([2.0, 5.0, 0.4]);
|
||||||
|
|
||||||
|
let tensor = TestAutodiffTensor::from_data(data).require_grad();
|
||||||
|
let tensor_out = tensor.clone().recip();
|
||||||
|
|
||||||
|
let grads = tensor_out.backward();
|
||||||
|
let grad = tensor.grad(&grads).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5]));
|
||||||
|
grad.to_data()
|
||||||
|
.assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -126,6 +126,7 @@ Those operations are only available for `Float` tensors.
|
||||||
| `tensor.erf()` | `tensor.erf()` |
|
| `tensor.erf()` | `tensor.erf()` |
|
||||||
| `tensor.powf(value)` | `tensor.pow(value)` |
|
| `tensor.powf(value)` | `tensor.pow(value)` |
|
||||||
| `tensor.sqrt()` | `tensor.sqrt()` |
|
| `tensor.sqrt()` | `tensor.sqrt()` |
|
||||||
|
| `tensor.recip()` | `tensor.reciprocal()` |
|
||||||
| `tensor.cos()` | `tensor.cos()` |
|
| `tensor.cos()` | `tensor.cos()` |
|
||||||
| `tensor.sin()` | `tensor.sin()` |
|
| `tensor.sin()` | `tensor.sin()` |
|
||||||
| `tensor.tanh()` | `tensor.tanh()` |
|
| `tensor.tanh()` | `tensor.tanh()` |
|
||||||
|
|
|
@ -56,6 +56,7 @@ mod tests {
|
||||||
burn_tensor::testgen_arg!();
|
burn_tensor::testgen_arg!();
|
||||||
burn_tensor::testgen_cast!();
|
burn_tensor::testgen_cast!();
|
||||||
burn_tensor::testgen_cat!();
|
burn_tensor::testgen_cat!();
|
||||||
|
burn_tensor::testgen_recip!();
|
||||||
burn_tensor::testgen_clamp!();
|
burn_tensor::testgen_clamp!();
|
||||||
burn_tensor::testgen_cos!();
|
burn_tensor::testgen_cos!();
|
||||||
// burn_tensor::testgen_div!();
|
// burn_tensor::testgen_div!();
|
||||||
|
@ -133,6 +134,7 @@ mod tests {
|
||||||
burn_autodiff::testgen_ad_mul!();
|
burn_autodiff::testgen_ad_mul!();
|
||||||
burn_autodiff::testgen_ad_neg!();
|
burn_autodiff::testgen_ad_neg!();
|
||||||
burn_autodiff::testgen_ad_powf!();
|
burn_autodiff::testgen_ad_powf!();
|
||||||
|
burn_autodiff::testgen_ad_recip!();
|
||||||
burn_autodiff::testgen_ad_reshape!();
|
burn_autodiff::testgen_ad_reshape!();
|
||||||
burn_autodiff::testgen_ad_sin!();
|
burn_autodiff::testgen_ad_sin!();
|
||||||
burn_autodiff::testgen_ad_softmax!();
|
burn_autodiff::testgen_ad_softmax!();
|
||||||
|
|
|
@ -442,4 +442,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
|
||||||
) -> FloatTensor<Self, D> {
|
) -> FloatTensor<Self, D> {
|
||||||
CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
|
CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
CandleTensor::new(tensor.tensor.recip().unwrap())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,6 +100,11 @@ pub enum FloatOpsDescription<B: FusionBackend> {
|
||||||
(TensorDescription, Distribution<FloatElem<B>>),
|
(TensorDescription, Distribution<FloatElem<B>>),
|
||||||
Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
|
Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
|
||||||
),
|
),
|
||||||
|
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
|
||||||
|
Recip(
|
||||||
|
UnaryOpsDescription,
|
||||||
|
Box<dyn Ops<B, Args = UnaryOpsDescription>>,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation description specific to module.
|
/// Operation description specific to module.
|
||||||
|
@ -1252,6 +1257,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
|
||||||
FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
|
FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
|
||||||
FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
|
FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
|
||||||
FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
|
FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
|
||||||
|
FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input),
|
||||||
FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
|
FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
|
||||||
FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
|
FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
|
||||||
FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
|
FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
|
||||||
|
@ -1268,6 +1274,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
|
||||||
FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
|
FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
|
||||||
FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
|
FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
|
||||||
FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
|
FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
|
||||||
|
FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles),
|
||||||
FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
|
FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
|
||||||
FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
|
FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
|
||||||
FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),
|
FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),
|
||||||
|
|
|
@ -1391,6 +1391,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
|
unary_float_ops!(Recip, B::recip);
|
||||||
|
|
||||||
|
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
|
||||||
|
out.client
|
||||||
|
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
|
||||||
|
UnaryOpsDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
},
|
||||||
|
Box::new(Recip::<D>),
|
||||||
|
)));
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unary_float_ops!(TanhOps, B::erf);
|
unary_float_ops!(TanhOps, B::erf);
|
||||||
|
|
||||||
|
|
|
@ -134,7 +134,7 @@ represent the corresponding Burn Op.
|
||||||
| [RandomUniform][128] | ❌ | ✅ |
|
| [RandomUniform][128] | ❌ | ✅ |
|
||||||
| [RandomUniformLike][129] | ❌ | ✅ |
|
| [RandomUniformLike][129] | ❌ | ✅ |
|
||||||
| [Range][130] | ❌ | ✅ |
|
| [Range][130] | ❌ | ✅ |
|
||||||
| [Reciprocal][131] | ❌ | ❌ |
|
| [Reciprocal][131] | ✅ | ✅ |
|
||||||
| [ReduceL][132] | ❌ | ❌ |
|
| [ReduceL][132] | ❌ | ❌ |
|
||||||
| [ReduceLogSum][133] | ❌ | ❌ |
|
| [ReduceLogSum][133] | ❌ | ❌ |
|
||||||
| [ReduceLogSumExp][134] | ❌ | ❌ |
|
| [ReduceLogSumExp][134] | ❌ | ❌ |
|
||||||
|
|
|
@ -27,6 +27,7 @@ fn main() {
|
||||||
.input("tests/log_softmax/log_softmax.onnx")
|
.input("tests/log_softmax/log_softmax.onnx")
|
||||||
.input("tests/maxpool2d/maxpool2d.onnx")
|
.input("tests/maxpool2d/maxpool2d.onnx")
|
||||||
.input("tests/mul/mul.onnx")
|
.input("tests/mul/mul.onnx")
|
||||||
|
.input("tests/recip/recip.onnx")
|
||||||
.input("tests/relu/relu.onnx")
|
.input("tests/relu/relu.onnx")
|
||||||
.input("tests/reshape/reshape.onnx")
|
.input("tests/reshape/reshape.onnx")
|
||||||
.input("tests/sigmoid/sigmoid.onnx")
|
.input("tests/sigmoid/sigmoid.onnx")
|
||||||
|
|
|
@ -34,6 +34,7 @@ include_models!(
|
||||||
log_softmax,
|
log_softmax,
|
||||||
maxpool2d,
|
maxpool2d,
|
||||||
mul,
|
mul,
|
||||||
|
recip,
|
||||||
relu,
|
relu,
|
||||||
reshape,
|
reshape,
|
||||||
sigmoid,
|
sigmoid,
|
||||||
|
@ -591,4 +592,17 @@ mod tests {
|
||||||
let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
|
let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
|
||||||
output.to_data().assert_approx_eq(&expected, 4);
|
output.to_data().assert_approx_eq(&expected, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn recip() {
|
||||||
|
// Initialize the model
|
||||||
|
let model = recip::Model::<Backend>::new();
|
||||||
|
|
||||||
|
// Run the model
|
||||||
|
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
|
||||||
|
let output = model.forward(input);
|
||||||
|
// data from pyTorch
|
||||||
|
let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
|
||||||
|
output.to_data().assert_approx_eq(&expected, 4);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
pytorch2.1.0:‰
|
||||||
|
0
|
||||||
|
onnx::Reciprocal_01/Reciprocal"
|
||||||
|
Reciprocal
|
||||||
|
main_graphZ,
|
||||||
|
onnx::Reciprocal_0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
b
|
||||||
|
1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
B
|
|
@ -0,0 +1,42 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# used to generate model: onnx-tests/tests/recip/recip.onnx
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Set random seed for reproducibility
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
# Export to onnx
|
||||||
|
model = Model()
|
||||||
|
model.eval()
|
||||||
|
device = torch.device("cpu")
|
||||||
|
onnx_name = "recip.onnx"
|
||||||
|
dummy_input = torch.randn(1, 2, 3, 4, device=device)
|
||||||
|
|
||||||
|
torch.onnx.export(model, (dummy_input), onnx_name,
|
||||||
|
verbose=False, opset_version=16)
|
||||||
|
|
||||||
|
print("Finished exporting model to {}".format(onnx_name))
|
||||||
|
|
||||||
|
# Output some test data for use in the test
|
||||||
|
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
|
||||||
|
|
||||||
|
print("Test input data: {}".format(test_input))
|
||||||
|
output = model.forward(test_input)
|
||||||
|
print("Test output data: {}".format(output))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -26,6 +26,7 @@ pub enum UnaryNodeKind {
|
||||||
LogSoftmax,
|
LogSoftmax,
|
||||||
Softmax,
|
Softmax,
|
||||||
Relu,
|
Relu,
|
||||||
|
Reciprocal,
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
Tanh,
|
Tanh,
|
||||||
Transpose,
|
Transpose,
|
||||||
|
@ -40,6 +41,7 @@ impl UnaryNodeKind {
|
||||||
Self::LogSoftmax => "log_softmax",
|
Self::LogSoftmax => "log_softmax",
|
||||||
Self::Softmax => "softmax",
|
Self::Softmax => "softmax",
|
||||||
Self::Relu => "relu",
|
Self::Relu => "relu",
|
||||||
|
Self::Reciprocal => "reciprocal",
|
||||||
Self::Sigmoid => "sigmoid",
|
Self::Sigmoid => "sigmoid",
|
||||||
Self::Tanh => "tanh",
|
Self::Tanh => "tanh",
|
||||||
Self::Transpose => "transpose",
|
Self::Transpose => "transpose",
|
||||||
|
@ -141,6 +143,11 @@ impl UnaryNode {
|
||||||
Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
|
Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn reciprocal(input: Type, output: Type) -> Self {
|
||||||
|
let function = move |input| quote! { #input.recip() };
|
||||||
|
Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
|
||||||
|
}
|
||||||
|
|
||||||
/// Casts the input to the output type.
|
/// Casts the input to the output type.
|
||||||
///
|
///
|
||||||
/// Currently this function only supports the following conversions:
|
/// Currently this function only supports the following conversions:
|
||||||
|
@ -334,6 +341,25 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unary_codegen_reciprocal() {
|
||||||
|
one_node_graph(
|
||||||
|
UnaryNode::reciprocal(
|
||||||
|
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||||
|
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||||
|
),
|
||||||
|
quote! {
|
||||||
|
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
|
let tensor2 = tensor1.recip();
|
||||||
|
|
||||||
|
tensor2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
vec!["tensor1".to_string()],
|
||||||
|
vec!["tensor2".to_string()],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_unary_codegen_cast() {
|
fn test_unary_codegen_cast() {
|
||||||
one_node_graph(
|
one_node_graph(
|
||||||
|
|
|
@ -79,6 +79,7 @@ pub fn dim_inference(
|
||||||
NodeType::Erf => same_as_input(node),
|
NodeType::Erf => same_as_input(node),
|
||||||
NodeType::Sqrt => same_as_input(node),
|
NodeType::Sqrt => same_as_input(node),
|
||||||
NodeType::Tanh => same_as_input(node),
|
NodeType::Tanh => same_as_input(node),
|
||||||
|
NodeType::Reciprocal => same_as_input(node),
|
||||||
NodeType::Softmax => same_as_input(node),
|
NodeType::Softmax => same_as_input(node),
|
||||||
NodeType::ReduceMean => mean_update_outputs(node),
|
NodeType::ReduceMean => mean_update_outputs(node),
|
||||||
NodeType::Constant => constant_update_outputs(node),
|
NodeType::Constant => constant_update_outputs(node),
|
||||||
|
|
|
@ -250,6 +250,7 @@ impl ONNXGraph {
|
||||||
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
|
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
|
||||||
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
|
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
|
||||||
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
|
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
|
||||||
|
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
|
||||||
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
||||||
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
|
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
|
||||||
NodeType::Concat => graph.register(Self::concat_conversion(node)),
|
NodeType::Concat => graph.register(Self::concat_conversion(node)),
|
||||||
|
@ -447,6 +448,13 @@ impl ONNXGraph {
|
||||||
UnaryNode::sigmoid(input, output)
|
UnaryNode::sigmoid(input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reciprocal_conversion(node: Node) -> UnaryNode {
|
||||||
|
let input = node.inputs.get(0).unwrap().to_type();
|
||||||
|
let output = node.outputs.get(0).unwrap().to_type();
|
||||||
|
|
||||||
|
UnaryNode::reciprocal(input, output)
|
||||||
|
}
|
||||||
|
|
||||||
fn log_softmax_conversion(node: Node) -> UnaryNode {
|
fn log_softmax_conversion(node: Node) -> UnaryNode {
|
||||||
let input = node.inputs.get(0).unwrap().to_type();
|
let input = node.inputs.get(0).unwrap().to_type();
|
||||||
let output = node.outputs.get(0).unwrap().to_type();
|
let output = node.outputs.get(0).unwrap().to_type();
|
||||||
|
|
|
@ -181,6 +181,13 @@ where
|
||||||
NdArrayTensor { array }
|
NdArrayTensor { array }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||||
|
let array = tensor.array.map(|x| 1.elem::<E>() / *x);
|
||||||
|
let array = array.into_shared();
|
||||||
|
|
||||||
|
NdArrayTensor { array }
|
||||||
|
}
|
||||||
|
|
||||||
pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
|
pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
|
||||||
let data = Data::from([tensor.array.mean().unwrap()]);
|
let data = Data::from([tensor.array.mean().unwrap()]);
|
||||||
NdArrayTensor::from_data(data)
|
NdArrayTensor::from_data(data)
|
||||||
|
|
|
@ -127,6 +127,10 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
|
||||||
Self::mul_scalar(tensor, (-1f32).elem::<E>())
|
Self::mul_scalar(tensor, (-1f32).elem::<E>())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||||
|
NdArrayMathOps::recip(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
fn swap_dims<const D: usize>(
|
fn swap_dims<const D: usize>(
|
||||||
tensor: NdArrayTensor<E, D>,
|
tensor: NdArrayTensor<E, D>,
|
||||||
dim1: usize,
|
dim1: usize,
|
||||||
|
|
|
@ -172,6 +172,10 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
|
||||||
Self::mul_scalar(tensor, (-1f32).elem::<E>())
|
Self::mul_scalar(tensor, (-1f32).elem::<E>())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||||
|
TchTensor::new(tensor.tensor.reciprocal())
|
||||||
|
}
|
||||||
|
|
||||||
fn swap_dims<const D: usize>(
|
fn swap_dims<const D: usize>(
|
||||||
tensor: TchTensor<E, D>,
|
tensor: TchTensor<E, D>,
|
||||||
dim1: usize,
|
dim1: usize,
|
||||||
|
|
|
@ -67,6 +67,11 @@ where
|
||||||
Self::new(B::powf(self.primitive, value))
|
Self::new(B::powf(self.primitive, value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies element wise reciprocal operation.
|
||||||
|
pub fn recip(self) -> Self {
|
||||||
|
Self::new(B::recip(self.primitive))
|
||||||
|
}
|
||||||
|
|
||||||
/// Applies element wise root square operation.
|
/// Applies element wise root square operation.
|
||||||
pub fn sqrt(self) -> Self {
|
pub fn sqrt(self) -> Self {
|
||||||
Self::new(B::sqrt(self.primitive))
|
Self::new(B::sqrt(self.primitive))
|
||||||
|
|
|
@ -410,6 +410,9 @@ pub trait TensorOps<B: Backend> {
|
||||||
Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
|
Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Calculates the reciprocals elementwise
|
||||||
|
fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
|
||||||
|
|
||||||
/// Transposes a tensor.
|
/// Transposes a tensor.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
|
@ -60,6 +60,7 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_one_hot!();
|
burn_tensor::testgen_one_hot!();
|
||||||
burn_tensor::testgen_powf!();
|
burn_tensor::testgen_powf!();
|
||||||
burn_tensor::testgen_random!();
|
burn_tensor::testgen_random!();
|
||||||
|
burn_tensor::testgen_recip!();
|
||||||
burn_tensor::testgen_repeat!();
|
burn_tensor::testgen_repeat!();
|
||||||
burn_tensor::testgen_reshape!();
|
burn_tensor::testgen_reshape!();
|
||||||
burn_tensor::testgen_select!();
|
burn_tensor::testgen_select!();
|
||||||
|
|
|
@ -28,6 +28,7 @@ mod neg;
|
||||||
mod one_hot;
|
mod one_hot;
|
||||||
mod powf;
|
mod powf;
|
||||||
mod random;
|
mod random;
|
||||||
|
mod recip;
|
||||||
mod repeat;
|
mod repeat;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
mod select;
|
mod select;
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
#[burn_tensor_testgen::testgen(recip)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Data, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_recip_ops() {
|
||||||
|
let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||||
|
|
||||||
|
let data_actual = tensor.recip().into_data();
|
||||||
|
|
||||||
|
let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);
|
||||||
|
data_expected.assert_approx_eq(&data_actual, 3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -510,4 +510,17 @@ where
|
||||||
) -> FloatTensor<Self, D> {
|
) -> FloatTensor<Self, D> {
|
||||||
kernel::clamp(tensor, min, max)
|
kernel::clamp(tensor, min, max)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn recip<const D: usize>(
|
||||||
|
tensor: FloatTensor<Wgpu<G, F, I>, D>,
|
||||||
|
) -> FloatTensor<Wgpu<G, F, I>, D> {
|
||||||
|
unary!(Recip, func "1.0 /");
|
||||||
|
unary_inplace!(RecipInplace, func "1.0 /");
|
||||||
|
|
||||||
|
if tensor.can_mut() {
|
||||||
|
return unary_inplace_default::<RecipInplace, F, D>(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
unary_default::<Recip, F, D>(tensor)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue