Add is_close and all_close tensor operators (#1389)

* Add is_close and all_close tensor operators

* Fix broken build issues

* Fix the table

* Add tests to candle
Dilshod Tadjibaev 2024-03-01 15:37:14 -06:00 committed by GitHub
parent 201e7f87c9
commit d43a0b3f90
10 changed files with 283 additions and 179 deletions

View File

@ -13,24 +13,27 @@ Tensor<B, D, Bool> // Bool tensor
Note that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by
backend implementations.
Burn Tensors are defined by the number of dimensions D in their declaration, as opposed to their
shape. The actual shape of the tensor is inferred from its initialization. For example, a Tensor of
size (5,) is initialized as below:
```rust, ignore
let floats = [1.0, 2.0, 3.0, 4.0, 5.0];
// correct: Tensor is 1-Dimensional with 5 elements
let tensor_1 = Tensor::<Backend, 1>::from_floats(floats);
// incorrect: let tensor_1 = Tensor::<Backend, 5>::from_floats(floats);
// this will lead to an error and is for creating a 5-D tensor
```
### Initialization
Burn Tensors are primarily initialized using the `from_data()` method which takes the `Data` struct
as input. The `Data` struct has two fields: value & shape. To retrieve the data from a tensor, the
method `.to_data()` should be employed when intending to reuse the tensor afterward. Alternatively,
`.into_data()` is recommended for one-time use. Let's look at a couple of examples for initializing
a tensor from different inputs.
```rust, ignore
@ -41,7 +44,8 @@ let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0]);
let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert());
// Initialization using from_floats (Recommended for f32 ElementType)
// Will be converted to Data internally.
// `.convert()` not needed as from_floats() is defined for a fixed ElementType
let tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0]);
// Initialization of Int Tensor from array slices
@ -61,54 +65,58 @@ let bmi = BodyMetrics{
height: 180,
weight: 80.0
};
let data = Data::from([bmi.age as f32, bmi.height as f32, bmi.weight]).convert();
let tensor_5 = Tensor::<Backend, 1>::from_data(data);
```
The `.convert()` method for the `Data` struct is called to ensure that the data's primitive type is
consistent across all backends. With the `.from_floats()` method the ElementType is fixed as f32 and
therefore no convert operation is required across backends. This operation can also be done at the
element wise level as:
`let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([(item.age as i64).elem()]))`. The
`ElementConversion` trait, however, needs to be imported for the element wise operation.
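As an illustration, here is a minimal sketch of that element wise conversion, with a plain `i64`
value standing in for the `item.age` field mentioned above:

```rust, ignore
use burn::tensor::ElementConversion;

let age: i64 = 30;
// `.elem()` converts the value into the backend's Int element type before the tensor is built.
let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([age.elem()]));
```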
## Ownership and Cloning
Almost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple
times will necessitate cloning it. Let's look at an example to understand the ownership rules and
cloning better. Suppose we want to do a simple min-max normalization of an input tensor.
```rust, ignore
let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0]);
let min = input.min();
let max = input.max();
let input = (input - min).div(max - min);
```
With PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules
will give an error and prevent using the input tensor after the first `.min()` operation. The
ownership of the input tensor is transferred to the variable `min`, and the input tensor is no
longer available for further operations. Burn Tensors, like most complex primitives, do not
implement the `Copy` trait and therefore have to be cloned explicitly. Now let's rewrite a working
example of doing min-max normalization with cloning.
```rust, ignore
let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0]);
let min = input.clone().min();
let max = input.clone().max();
let input = (input.clone() - min.clone()).div(max - min);
println!("{:?}", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0]
// Notice that max and min have been moved by the last operation,
// so the print below will give an error.
// If we want to use them for further operations,
// they will need to be cloned in a similar fashion.
// println!("{:?}", min.to_data());
```
We don't need to worry about memory overhead because, with cloning, the tensor's buffer isn't
copied; only its reference count is increased. This makes it possible to determine exactly how many
times a tensor is used, which is very convenient for reusing tensor buffers or even fusing
operations into a single kernel ([burn-fusion](https://burn.dev/docs/burn_fusion/index.htmls)). For
that reason, we don't provide explicit inplace operations. If a tensor is used only once, inplace
operations will always be used when available.
## Tensor Operations
@ -127,145 +135,146 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_data(data, device)` | N/A |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
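For instance, a short sketch combining a few of these operations (assuming `Backend` is a type
alias for the chosen backend, as in the examples above):

```rust, ignore
let tensor = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

// Reshape the 6-element vector into a 2 x 3 matrix.
let matrix = tensor.reshape([2, 3]);
println!("{:?}", matrix.dims()); // [2, 3]

// Add a leading dimension: the shape becomes [1, 2, 3].
let batched = matrix.clone().unsqueeze::<3>();

// Take the first row: the result has shape [1, 3].
let row = matrix.slice([0..1, 0..3]);
```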
### Numeric Operations
Those operations are available for numeric tensor kinds: `Float` and `Int`.
| Burn | PyTorch Equivalent |
| --------------------------------------------------------------- | ---------------------------------------------- |
| `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
| `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` |
| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
| `Tensor::zeros(shape, device)` | `torch.zeros(shape, device=device)` |
| `tensor.abs()` | `torch.abs(tensor)` |
| `tensor.add(other)` or `tensor + other` | `tensor + other` |
| `tensor.add_scalar(scalar)` or `tensor + scalar` | `tensor + scalar` |
| `tensor.all_close(other, rtol, atol)` | `torch.allclose(tensor, other, rtol, atol)` |
| `tensor.argmax(dim)` | `tensor.argmax(dim)` |
| `tensor.argmin(dim)` | `tensor.argmin(dim)` |
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
| `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` |
| `tensor.greater(other)` | `tensor.gt(other)` |
| `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` |
| `tensor.greater_equal(other)` | `tensor.ge(other)` |
| `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
| `tensor.is_close(other, rtol, atol)` | `torch.isclose(tensor, other, rtol, atol)` |
| `tensor.lower(other)` | `tensor.lt(other)` |
| `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
| `tensor.lower_equal(other)` | `tensor.le(other)` |
| `tensor.lower_equal_elem(scalar)` | `tensor.le(scalar)` |
| `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
| `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
| `tensor.max()` | `tensor.max()` |
| `tensor.max_dim(dim)` | `tensor.max(dim)` |
| `tensor.max_dim_with_indices(dim)` | N/A |
| `tensor.max_pair(other)` | `torch.Tensor.max(a,b)` |
| `tensor.mean()` | `tensor.mean()` |
| `tensor.mean_dim(dim)` | `tensor.mean(dim)` |
| `tensor.min()` | `tensor.min()` |
| `tensor.min_dim(dim)` | `tensor.min(dim)` |
| `tensor.min_dim_with_indices(dim)` | N/A |
| `tensor.min_pair(other)` | `torch.Tensor.min(a,b)` |
| `tensor.mul(other)` or `tensor * other` | `tensor * other` |
| `tensor.mul_scalar(scalar)` or `tensor * scalar` | `tensor * scalar` |
| `tensor.neg()` or `-tensor` | `-tensor` |
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.sub(other)` or `tensor - other` | `tensor - other` |
| `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
| `tensor.sum()` | `tensor.sum()` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
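As a usage sketch for the newly added closeness checks (passing `None` for the tolerances falls
back to the defaults, rtol = 1e-5 and atol = 1e-8; `Backend` is again assumed to be a backend type
alias):

```rust, ignore
let tensor_a = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0]);
let tensor_b = tensor_a.clone() + 1e-9;

// Element-wise comparison: a Bool tensor with the same shape as the inputs.
let close = tensor_a.clone().is_close(tensor_b.clone(), None, None);
println!("{:?}", close.into_data()); // [true, true, true]

// Scalar result: true only if every element is within tolerance.
assert!(tensor_a.all_close(tensor_b, None, None));
```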
### Float Operations
Those operations are only available for `Float` tensors.
| Burn API | PyTorch Equivalent |
| -------------------------------------------- | ---------------------------------- |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.erf()` | `tensor.erf()` |
| `tensor.exp()` | `tensor.exp()` |
| `tensor.from_floats(floats, device)` | N/A |
| `tensor.from_full_precision(tensor)` | N/A |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.log()` | `tensor.log()` |
| `tensor.log1p()` | `tensor.log1p()` |
| `tensor.matmul(other)` | `tensor.matmul(other)` |
| `tensor.one_hot(index, num_classes, device)` | N/A |
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
| `tensor.random(shape, distribution, device)` | N/A |
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.tanh()` | `tensor.tanh()` |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
| `tensor.var_mean_bias(dim)` | N/A |
| `tensor.zeros_like()` | `torch.zeros_like(tensor)` |
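A brief sketch combining a few of these (again assuming a `Backend` type alias and a default
device):

```rust, ignore
let device = Default::default();
let a = Tensor::<Backend, 2>::ones([2, 3], &device);
let b = Tensor::<Backend, 2>::ones([3, 2], &device);

// Matrix product of a (2 x 3) and b (3 x 2): the result is 2 x 2.
let product = a.clone().matmul(b);

// Element-wise square root, then swap the two dimensions: the shape becomes [3, 2].
let transposed = a.sqrt().transpose();
```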
### Int Operations
Those operations are only available for `Int` tensors.
| Burn API | PyTorch Equivalent |
| ------------------------------------------------ | ------------------------------------------------------- |
| `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` |
| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
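A small sketch of the `Int`-specific constructors; the device argument is shown by reference as in
the float examples, so treat the exact form as illustrative:

```rust, ignore
let device = Default::default();

// Integers 5, 6, 7, 8, 9 as a 1-D Int tensor.
let indices = Tensor::<Backend, 1, Int>::arange(5..10, &device);

// Every second value: 5, 7, 9.
let stepped = Tensor::<Backend, 1, Int>::arange_step(5..10, 2, &device);

// Convert to a Float tensor before mixing with float arithmetic.
let as_float = indices.float();
```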
### Bool Operations
Those operations are only available for `Bool` tensors.
| Burn API | PyTorch Equivalent |
| ---------------- | ----------------------------------- |
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.not()` | `tensor.logical_not()` |
@ -273,11 +282,12 @@ Those operations are only available for `Bool` tensors.
## Activation Functions
| Burn API | PyTorch Equivalent |
| ---------------------------------------- | ----------------------------------------------------- |
| `activation::gelu(tensor)` | Similar to `nn.functional.gelu(tensor)` |
| `activation::log_sigmoid(tensor)` | Similar to `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | Similar to `nn.functional.log_softmax(tensor, dim)` |
| `activation::mish(tensor)` | Similar to `nn.functional.mish(tensor)` |
| `activation::prelu(tensor,alpha)` | Similar to `nn.functional.prelu(tensor,weight)` |
| `activation::quiet_softmax(tensor, dim)` | Similar to `nn.functional.quiet_softmax(tensor, dim)` |
| `activation::relu(tensor)` | Similar to `nn.functional.relu(tensor)` |
| `activation::sigmoid(tensor)` | Similar to `nn.functional.sigmoid(tensor)` |
@ -285,4 +295,3 @@ Those operations are only available for `Bool` tensors.
| `activation::softmax(tensor, dim)` | Similar to `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | Similar to `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | Similar to `nn.functional.tanh(tensor)` |
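For example, a minimal sketch chaining two of these functions (assuming the `activation` module is
in scope via `burn::tensor::activation` and `Backend` is a backend alias):

```rust, ignore
use burn::tensor::activation;

let device = Default::default();
let logits = Tensor::<Backend, 2>::ones([2, 3], &device);

// Probabilities over the last dimension, then a ReLU; both preserve the [2, 3] shape.
let probs = activation::softmax(logits.clone(), 1);
let activated = activation::relu(logits);
```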

View File

@ -62,6 +62,7 @@ mod tests {
burn_tensor::testgen_recip!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_close!();
// burn_tensor::testgen_div!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();

View File

@ -224,7 +224,7 @@ impl<B: Backend> Lstm<B> {
mod tests {
use super::*;
use crate::{module::Param, nn::LinearRecord, TestBackend};
use burn_tensor::{Data, Distribution};
#[cfg(feature = "std")]
use crate::TestAutodiffBackend;
@ -360,6 +360,7 @@ mod tests {
#[test]
#[cfg(feature = "std")]
fn test_batched_backward_pass() {
use burn_tensor::Shape;
let device = Default::default();
let lstm = LstmConfig::new(64, 32, true).init(&device);
let shape: Shape<3> = [8, 10, 64].into();

View File

@ -11,7 +11,10 @@ use alloc::vec;
use burn_common::{reader::Reader, stub::Mutex};
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use serde::{Serialize, Serializer};
use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
@ -617,7 +620,6 @@ where
///
/// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
/// evaluates to True, False otherwise.
pub fn any(self) -> Tensor<B, 1, Bool> {
K::any(self.primitive)
}
@ -664,10 +666,33 @@ where
/// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
/// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
/// evaluates to True, False otherwise.
pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
K::all_dim(self.primitive, dim)
}
/// Convert the tensor into a scalar.
///
/// # Panics
///
/// If the tensor doesn't have one element.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn into_scalar(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let data = self.into_data();
data.value[0]
}
/// Convert the tensor into a scalar.
///
/// # Panics
///
/// If the tensor doesn't have one element.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn into_scalar(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let data = self.into_data().await;
data.value[0]
}
}
/// Iterator given by (Tensor::iter_dim).
@ -968,7 +993,7 @@ impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicOps<B: Backend>: TensorKind<B> {
/// The type of the tensor elements.
type Elem: 'static + Copy;
/// Creates an empty tensor with the given shape.
///

View File

@ -9,30 +9,6 @@ where
K: Numeric<B>,
K::Elem: Element,
{
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Convert the tensor into a scalar.
///
/// # Panics
///
/// If the tensor doesn't have one element.
pub fn into_scalar(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let data = self.into_data();
data.value[0]
}
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Convert the tensor into a scalar.
///
/// # Panics
///
/// If the tensor doesn't have one element.
pub async fn into_scalar(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let data = self.into_data().await;
data.value[0]
}
/// Applies element wise addition operation.
///
/// `y = x2 + x1`
@ -603,6 +579,67 @@ where
pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
Self::new(K::powi_scalar(self.primitive, other))
}
/// Checks element wise if the tensor is close to another tensor.
///
/// The tolerance is defined by the following equation:
///
/// ```text
/// abs(a - b) <= (atol + rtol * abs(b))
///
/// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
/// and `atol` is the absolute tolerance.
/// ```
///
/// # Arguments
///
/// * `other` - The tensor to compare with.
/// * `rtol` - Optional relative tolerance. Default is 1e-5.
/// * `atol` - Optional absolute tolerance. Default is 1e-8.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors.
pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
let rtol = rtol.unwrap_or(1e-5);
let atol = atol.unwrap_or(1e-8);
K::lower_equal(
K::abs(K::sub(self.primitive, other.primitive.clone())),
K::add_scalar(K::mul_scalar(K::abs(other.primitive), rtol), atol),
)
}
/// Checks if all elements are close to another tensor.
///
/// The tolerance is defined by the following equation:
///
/// ```text
///
/// abs(a - b) <= (atol + rtol * abs(b))
///
/// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
/// and `atol` is the absolute tolerance.
///
/// ```
///
/// # Arguments
///
/// * `other` - The tensor to compare with.
/// * `rtol` - Optional relative tolerance. Default is 1e-5.
/// * `atol` - Optional absolute tolerance. Default is 1e-8.
///
/// # Returns
///
/// A boolean scalar.
///
/// # Remarks
///
/// This method is only available for non-wasm targets or when the `wasm-sync` feature is enabled.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
self.is_close(other, rtol, atol).all().into_scalar()
}
}
impl<B, K> Tensor<B, 2, K>

View File

@ -42,6 +42,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_cat!();
burn_tensor::testgen_chunk!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_close!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();
burn_tensor::testgen_div!();

View File

@ -0,0 +1,24 @@
#[burn_tensor_testgen::testgen(close)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn test_is_close() {
let tensor1 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;
let data_actual = tensor1.is_close(tensor2, None, None).into_data();
let data_expected = Data::from([[true, true, true], [true, true, false]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn test_all_close() {
let tensor1 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;
assert!(!tensor1.clone().all_close(tensor2.clone(), None, None));
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-9;
assert!(tensor1.all_close(tensor2, None, None));
}
}

View File

@ -10,6 +10,7 @@ mod cast;
mod cat;
mod chunk;
mod clamp;
mod close;
mod cos;
mod create_like;
mod div;

View File

@ -249,6 +249,7 @@ impl CompilationSettings {
}
}
#[allow(dead_code)]
fn is_contiguous(strides: &[usize]) -> bool {
let mut current = 0;
@ -271,6 +272,7 @@ pub enum InputInfo {
impl InputInfo {
/// The item type of the input.
#[allow(dead_code)]
pub fn item(&self) -> Item {
match self {
InputInfo::Array {
@ -284,6 +286,7 @@ impl InputInfo {
impl OutputInfo {
/// The item type of the input.
#[allow(dead_code)]
pub fn item(&self) -> Item {
match self {
OutputInfo::ArrayWrite { item, local: _ } => *item,
@ -314,6 +317,7 @@ pub enum OutputInfo {
}
impl OutputInfo {
#[allow(dead_code)]
pub fn elem_size<R: Runtime>(&self) -> usize {
let elem = match self {
OutputInfo::ArrayWrite { item, local: _ } => bool_elem(item.elem()),

View File

@ -170,6 +170,7 @@ impl Scope {
}
}
#[allow(dead_code)]
pub(crate) fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> {
self.reads_global
.iter()