mirror of https://github.com/tracel-ai/burn.git
Add is_close and all_close tensor operators (#1389)
* Add is_close and all_close tensor operators
* Fix broken build issues
* Fix the table
* Add tests to candle

This commit is contained in:
parent 201e7f87c9
commit d43a0b3f90
@ -13,24 +13,27 @@ Tensor<B, D, Bool> // Bool tensor

Note that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by
backend implementations.

Burn Tensors are defined by the number of dimensions D in their declaration as opposed to their
shape. The actual shape of the tensor is inferred from its initialization. For example, a Tensor of
size (5,) is initialized as below:

```rust, ignore
let floats = [1.0, 2.0, 3.0, 4.0, 5.0];

// correct: Tensor is 1-Dimensional with 5 elements
let tensor_1 = Tensor::<Backend, 1>::from_floats(floats);

// incorrect: let tensor_1 = Tensor::<Backend, 5>::from_floats(floats);
// this will lead to an error and is for creating a 5-D tensor
```

### Initialization

Burn Tensors are primarily initialized using the `from_data()` method, which takes the `Data`
struct as input. The `Data` struct has two fields: value and shape. To retrieve the data from a
tensor, the method `.to_data()` should be employed when intending to reuse the tensor afterward.
Alternatively, `.into_data()` is recommended for one-time use. Let's look at a couple of examples
for initializing a tensor from different inputs.

```rust, ignore
@ -41,7 +44,8 @@ let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0]);
let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert());

// Initialization using from_floats (recommended for f32 ElementType)
// Will be converted to Data internally.
// `.convert()` is not needed as from_floats() is defined for a fixed ElementType.
let tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0]);

// Initialization of Int Tensor from array slices
@ -61,54 +65,58 @@ let bmi = BodyMetrics{
    height: 180,
    weight: 80.0
};
let data = Data::from([bmi.age as f32, bmi.height as f32, bmi.weight]).convert();
let tensor_5 = Tensor::<Backend, 1>::from_data(data);
```

The `.convert()` method for the Data struct is called to ensure that the data's primitive type is
consistent across all backends. With the `.from_floats()` method the ElementType is fixed as f32
and therefore no convert operation is required across backends. This operation can also be done at
the element-wise level as:
`let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([(item.age as i64).elem()]))`. The
`ElementConversion` trait however needs to be imported for the element-wise operation.
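For illustration, here is a minimal sketch of that element-wise conversion; the `Item` struct and its `age` field are hypothetical stand-ins for whatever record type holds the value:

```rust, ignore
use burn::tensor::ElementConversion; // brings `.elem()` into scope

// Hypothetical record type used only for this example.
struct Item {
    age: i32,
}

let item = Item { age: 25 };

// `.elem()` converts the scalar into the backend's Int element type.
let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([(item.age as i64).elem()]));
```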
## Ownership and Cloning

Almost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple
times will necessitate cloning it. Let's look at an example to understand the ownership rules and
cloning better. Suppose we want to do a simple min-max normalization of an input tensor.

```rust, ignore
let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0]);
let min = input.min();
let max = input.max();
let input = (input - min).div(max - min);
```

With PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules
will give an error and prevent using the input tensor after the first `.min()` operation. The
ownership of the input tensor is transferred to the variable `min` and the input tensor is no longer
available for further operations. Burn Tensors, like most complex primitives, do not implement the
`Copy` trait and therefore have to be cloned explicitly. Now let's rewrite a working example of
doing min-max normalization with cloning.

```rust, ignore
let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0]);
let min = input.clone().min();
let max = input.clone().max();
let input = (input.clone() - min.clone()).div(max - min);
println!("{:?}", input.to_data()); // Success: [0.0, 0.33333334, 0.6666667, 1.0]

// Notice that max and min have been moved in the last operation, so
// the print below would give an error.
// If we want to use them for further operations,
// they will need to be cloned in a similar fashion.
// println!("{:?}", min.to_data());
```

We don't need to worry about memory overhead because with cloning, the tensor's buffer isn't
copied; only a reference to it is increased. This makes it possible to determine exactly how
many times a tensor is used, which is very convenient for reusing tensor buffers or even fusing
operations into a single kernel ([burn-fusion](https://burn.dev/docs/burn_fusion/index.html)). For
that reason, we don't provide explicit inplace operations. If a tensor is used only once, inplace
operations will always be used when available.

## Tensor Operations
@ -127,145 +135,146 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn                                  | PyTorch Equivalent                   |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)`           | `torch.cat(tensors, dim)`            |
| `Tensor::empty(shape, device)`        | `torch.empty(shape, device=device)`  |
| `Tensor::from_data(data, device)`     | N/A                                  |
| `Tensor::from_primitive(primitive)`   | N/A                                  |
| `Tensor::stack(tensors, dim)`         | `torch.stack(tensors, dim)`          |
| `tensor.all()`                        | `tensor.all()`                       |
| `tensor.all_dim(dim)`                 | `tensor.all(dim)`                    |
| `tensor.any()`                        | `tensor.any()`                       |
| `tensor.any_dim(dim)`                 | `tensor.any(dim)`                    |
| `tensor.chunk(num_chunks, dim)`       | `tensor.chunk(num_chunks, dim)`      |
| `tensor.device()`                     | `tensor.device`                      |
| `tensor.dims()`                       | `tensor.size()`                      |
| `tensor.equal(other)`                 | `x == y`                             |
| `tensor.flatten(start_dim, end_dim)`  | `tensor.flatten(start_dim, end_dim)` |
| `tensor.into_data()`                  | N/A                                  |
| `tensor.into_primitive()`             | N/A                                  |
| `tensor.narrow(dim, start, length)`   | `tensor.narrow(dim, start, length)`  |
| `tensor.not_equal(other)`             | `x != y`                             |
| `tensor.repeat(2, 4)`                 | `tensor.repeat([1, 1, 4])`           |
| `tensor.reshape(shape)`               | `tensor.view(shape)`                 |
| `tensor.shape()`                      | `tensor.shape`                       |
| `tensor.slice(ranges)`                | `tensor[(*ranges,)]`                 |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values`        |
| `tensor.squeeze(dim)`                 | `tensor.squeeze(dim)`                |
| `tensor.to_data()`                    | N/A                                  |
| `tensor.to_device(device)`            | `tensor.to(device)`                  |
| `tensor.unsqueeze()`                  | `tensor.unsqueeze(0)`                |
| `tensor.unsqueeze_dim(dim)`           | `tensor.unsqueeze(dim)`              |
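As a quick illustration, a minimal sketch combining a few of these basic operations (assuming a backend `B` is in scope, as in the earlier examples; shapes shown in comments):

```rust, ignore
let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

let reshaped = tensor.clone().reshape([3, 2]);              // shape [3, 2]
let flat: Tensor<B, 1> = tensor.clone().flatten(0, 1);      // shape [6]
let stacked = Tensor::cat(vec![tensor.clone(), tensor], 0); // shape [4, 3]
```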
### Numeric Operations

Those operations are available for numeric tensor kinds: `Float` and `Int`.

| Burn                                                             | PyTorch Equivalent                             |
| ---------------------------------------------------------------- | ---------------------------------------------- |
| `Tensor::full(shape, fill_value, device)`                        | `torch.full(shape, fill_value, device=device)` |
| `Tensor::ones(shape, device)`                                    | `torch.ones(shape, device=device)`             |
| `Tensor::zeros(shape)`                                           | `torch.zeros(shape)`                           |
| `Tensor::zeros(shape, device)`                                   | `torch.zeros(shape, device=device)`            |
| `tensor.abs()`                                                   | `torch.abs(tensor)`                            |
| `tensor.add(other)` or `tensor + other`                          | `tensor + other`                               |
| `tensor.add_scalar(scalar)` or `tensor + scalar`                 | `tensor + scalar`                              |
| `tensor.all_close(other, rtol, atol)`                            | `torch.allclose(tensor, other, rtol, atol)`    |
| `tensor.argmax(dim)`                                             | `tensor.argmax(dim)`                           |
| `tensor.argmin(dim)`                                             | `tensor.argmin(dim)`                           |
| `tensor.clamp(min, max)`                                         | `torch.clamp(tensor, min=min, max=max)`        |
| `tensor.clamp_max(max)`                                          | `torch.clamp(tensor, max=max)`                 |
| `tensor.clamp_min(min)`                                          | `torch.clamp(tensor, min=min)`                 |
| `tensor.div(other)` or `tensor / other`                          | `tensor / other`                               |
| `tensor.div_scalar(scalar)` or `tensor / scalar`                 | `tensor / scalar`                              |
| `tensor.equal_elem(other)`                                       | `tensor.eq(other)`                             |
| `tensor.gather(dim, indices)`                                    | `torch.gather(tensor, dim, indices)`           |
| `tensor.greater(other)`                                          | `tensor.gt(other)`                             |
| `tensor.greater_elem(scalar)`                                    | `tensor.gt(scalar)`                            |
| `tensor.greater_equal(other)`                                    | `tensor.ge(other)`                             |
| `tensor.greater_equal_elem(scalar)`                              | `tensor.ge(scalar)`                            |
| `tensor.into_scalar()`                                           | `tensor.item()` (for single-element tensors)   |
| `tensor.is_close(other, rtol, atol)`                             | `torch.isclose(tensor, other, rtol, atol)`     |
| `tensor.lower(other)`                                            | `tensor.lt(other)`                             |
| `tensor.lower_elem(scalar)`                                      | `tensor.lt(scalar)`                            |
| `tensor.lower_equal(other)`                                      | `tensor.le(other)`                             |
| `tensor.lower_equal_elem(scalar)`                                | `tensor.le(scalar)`                            |
| `tensor.mask_fill(mask, value)`                                  | `tensor.masked_fill(mask, value)`              |
| `tensor.mask_where(mask, value_tensor)`                          | `torch.where(mask, value_tensor, tensor)`      |
| `tensor.max()`                                                   | `tensor.max()`                                 |
| `tensor.max_dim(dim)`                                            | `tensor.max(dim)`                              |
| `tensor.max_dim_with_indices(dim)`                               | N/A                                            |
| `tensor.max_pair(other)`                                         | `torch.Tensor.max(a, b)`                       |
| `tensor.mean()`                                                  | `tensor.mean()`                                |
| `tensor.mean_dim(dim)`                                           | `tensor.mean(dim)`                             |
| `tensor.min()`                                                   | `tensor.min()`                                 |
| `tensor.min_dim(dim)`                                            | `tensor.min(dim)`                              |
| `tensor.min_dim_with_indices(dim)`                               | N/A                                            |
| `tensor.min_pair(other)`                                         | `torch.Tensor.min(a, b)`                       |
| `tensor.mul(other)` or `tensor * other`                          | `tensor * other`                               |
| `tensor.mul_scalar(scalar)` or `tensor * scalar`                 | `tensor * scalar`                              |
| `tensor.neg()` or `-tensor`                                      | `-tensor`                                      |
| `tensor.not_equal_elem(scalar)`                                  | `tensor.ne(scalar)`                            |
| `tensor.powf(other)` or `tensor.powi(intother)`                  | `tensor.pow(other)`                            |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)`  | `tensor.pow(scalar)`                           |
| `tensor.scatter(dim, indices, values)`                           | `tensor.scatter_add(dim, indices, values)`     |
| `tensor.select(dim, indices)`                                    | `tensor.index_select(dim, indices)`            |
| `tensor.select_assign(dim, indices, values)`                     | N/A                                            |
| `tensor.sub(other)` or `tensor - other`                          | `tensor - other`                               |
| `tensor.sub_scalar(scalar)` or `tensor - scalar`                 | `tensor - scalar`                              |
| `tensor.sum()`                                                   | `tensor.sum()`                                 |
| `tensor.sum_dim(dim)`                                            | `tensor.sum(dim)`                              |
| `tensor.tril(diagonal)`                                          | `torch.tril(tensor, diagonal)`                 |
| `tensor.triu(diagonal)`                                          | `torch.triu(tensor, diagonal)`                 |
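Since this commit introduces them, here is a minimal sketch of the two new closeness operators, mirroring the usage in the new tests; the shapes and values are arbitrary:

```rust, ignore
let tensor1 = Tensor::<B, 2>::from_floats([[0.0, 1.0], [2.0, 3.0]]);
let tensor2 = tensor1.clone() + 1e-9;

// Element-wise check with default tolerances (rtol = 1e-5, atol = 1e-8).
let mask: Tensor<B, 2, Bool> = tensor1.clone().is_close(tensor2.clone(), None, None);

// Scalar check that every element is close, with a custom relative tolerance.
let close: bool = tensor1.all_close(tensor2, Some(1e-4), None);
```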
### Float Operations

Those operations are only available for `Float` tensors.

| Burn API                                     | PyTorch Equivalent                 |
| -------------------------------------------- | ---------------------------------- |
| `tensor.cos()`                               | `tensor.cos()`                     |
| `tensor.erf()`                               | `tensor.erf()`                     |
| `tensor.exp()`                               | `tensor.exp()`                     |
| `tensor.from_floats(floats, device)`         | N/A                                |
| `tensor.from_full_precision(tensor)`         | N/A                                |
| `tensor.int()`                               | Similar to `tensor.to(torch.long)` |
| `tensor.log()`                               | `tensor.log()`                     |
| `tensor.log1p()`                             | `tensor.log1p()`                   |
| `tensor.matmul(other)`                       | `tensor.matmul(other)`             |
| `tensor.one_hot(index, num_classes, device)` | N/A                                |
| `tensor.ones_like()`                         | `torch.ones_like(tensor)`          |
| `tensor.random(shape, distribution, device)` | N/A                                |
| `tensor.random_like(distribution)`           | `torch.rand_like()` only uniform   |
| `tensor.recip()`                             | `tensor.reciprocal()`              |
| `tensor.sin()`                               | `tensor.sin()`                     |
| `tensor.sqrt()`                              | `tensor.sqrt()`                    |
| `tensor.swap_dims(dim1, dim2)`               | `tensor.transpose(dim1, dim2)`     |
| `tensor.tanh()`                              | `tensor.tanh()`                    |
| `tensor.to_full_precision()`                 | `tensor.to(torch.float)`           |
| `tensor.transpose()`                         | `tensor.T`                         |
| `tensor.var(dim)`                            | `tensor.var(dim)`                  |
| `tensor.var_bias(dim)`                       | N/A                                |
| `tensor.var_mean(dim)`                       | N/A                                |
| `tensor.var_mean_bias(dim)`                  | N/A                                |
| `tensor.zeros_like()`                        | `torch.zeros_like(tensor)`         |
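A minimal sketch of a couple of float-only operations (values are arbitrary):

```rust, ignore
let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]]);

// `transpose()` swaps the last two dimensions, like `tensor.T`.
let at = a.clone().transpose();

// Matrix product: [2, 2] x [2, 2] -> [2, 2], then element-wise sqrt.
let product = a.matmul(at).sqrt();
```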
### Int Operations

Those operations are only available for `Int` tensors.

| Burn API                                         | PyTorch Equivalent                                      |
| ------------------------------------------------ | ------------------------------------------------------- |
| `tensor.arange(5..10, device)`                   | `torch.arange(start=5, end=10, device=device)`          |
| `tensor.arange_step(5..10, 2, device)`           | `torch.arange(start=5, end=10, step=2, device=device)`  |
| `tensor.float()`                                 | Similar to `tensor.to(torch.float)`                     |
| `tensor.from_ints(ints)`                         | N/A                                                     |
| `tensor.int_random(shape, distribution, device)` | N/A                                                     |
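A minimal sketch of the range constructor; note the table abbreviates the call site, and whether the device is passed by value or by reference may differ across versions:

```rust, ignore
let device = Default::default();

// Integer range [5, 10) as a 1-D Int tensor: [5, 6, 7, 8, 9].
let ints = Tensor::<B, 1, Int>::arange(5..10, &device);

// Convert to Float when real-valued arithmetic is needed.
let halves = ints.float().div_scalar(2.0);
```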
### Bool Operations

Those operations are only available for `Bool` tensors.

| Burn API         | PyTorch Equivalent                  |
| ---------------- | ----------------------------------- |
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `tensor.int()`   | Similar to `tensor.to(torch.long)`  |
| `tensor.not()`   | `tensor.logical_not()`              |
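Bool tensors usually come from comparisons; a small sketch chaining them with the conversions above (reusing `tensor1` and `tensor2` from the numeric sketch):

```rust, ignore
let mask = tensor1.is_close(tensor2, None, None); // Tensor<B, 2, Bool>

// Invert the mask, convert to Int, and count the mismatching elements.
let mismatches = mask.not().int().sum().into_scalar();
```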
@ -273,11 +282,12 @@ Those operations are only available for `Bool` tensors.

## Activation Functions

| Burn API                                 | PyTorch Equivalent                                    |
| ---------------------------------------- | ----------------------------------------------------- |
| `activation::gelu(tensor)`               | Similar to `nn.functional.gelu(tensor)`               |
| `activation::log_sigmoid(tensor)`        | Similar to `nn.functional.log_sigmoid(tensor)`        |
| `activation::log_softmax(tensor, dim)`   | Similar to `nn.functional.log_softmax(tensor, dim)`   |
| `activation::mish(tensor)`               | Similar to `nn.functional.mish(tensor)`               |
| `activation::prelu(tensor, alpha)`       | Similar to `nn.functional.prelu(tensor, weight)`      |
| `activation::quiet_softmax(tensor, dim)` | Similar to `nn.functional.quiet_softmax(tensor, dim)` |
| `activation::relu(tensor)`               | Similar to `nn.functional.relu(tensor)`               |
| `activation::sigmoid(tensor)`            | Similar to `nn.functional.sigmoid(tensor)`            |
@ -285,4 +295,3 @@ Those operations are only available for `Bool` tensors.
| `activation::softmax(tensor, dim)`   | Similar to `nn.functional.softmax(tensor, dim)`   |
| `activation::softplus(tensor, beta)` | Similar to `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)`           | Similar to `nn.functional.tanh(tensor)`           |
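For completeness, a minimal sketch of calling one of the activation free functions (the `burn::tensor::activation` import path is assumed here):

```rust, ignore
use burn::tensor::activation;

let logits = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]]);

// Softmax over the last dimension; each row sums to 1.
let probs = activation::softmax(logits, 1);
```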
@ -62,6 +62,7 @@ mod tests {
    burn_tensor::testgen_recip!();
    burn_tensor::testgen_clamp!();
    burn_tensor::testgen_cos!();
    burn_tensor::testgen_close!();
    // burn_tensor::testgen_div!();
    burn_tensor::testgen_erf!();
    burn_tensor::testgen_exp!();
@ -224,7 +224,7 @@ impl<B: Backend> Lstm<B> {
mod tests {
    use super::*;
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::{Data, Distribution};

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

@ -360,6 +360,7 @@ mod tests {
    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use burn_tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape<3> = [8, 10, 64].into();
@ -11,7 +11,10 @@ use alloc::vec;

use burn_common::{reader::Reader, stub::Mutex};
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};

#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use serde::{Serialize, Serializer};

use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
@ -617,7 +620,6 @@ where
    ///
    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
    /// evaluates to True, False otherwise.
    pub fn any(self) -> Tensor<B, 1, Bool> {
        K::any(self.primitive)
    }
@ -664,10 +666,33 @@ where
    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
    /// evaluate to True, False otherwise.
    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
        K::all_dim(self.primitive, dim)
    }

    /// Convert the tensor into a scalar.
    ///
    /// # Panics
    ///
    /// If the tensor doesn't have one element.
    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
    pub fn into_scalar(self) -> K::Elem {
        check!(TensorCheck::into_scalar(&self.shape()));
        let data = self.into_data();
        data.value[0]
    }

    /// Convert the tensor into a scalar.
    ///
    /// # Panics
    ///
    /// If the tensor doesn't have one element.
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn into_scalar(self) -> K::Elem {
        check!(TensorCheck::into_scalar(&self.shape()));
        let data = self.into_data().await;
        data.value[0]
    }
}

/// Iterator given by (Tensor::iter_dim).
@ -968,7 +993,7 @@ impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicOps<B: Backend>: TensorKind<B> {
    /// The type of the tensor elements.
    type Elem: 'static + Copy;

    /// Creates an empty tensor with the given shape.
    ///
@ -9,30 +9,6 @@ where
    K: Numeric<B>,
    K::Elem: Element,
{
    /// Applies element wise addition operation.
    ///
    /// `y = x2 + x1`
@ -603,6 +579,67 @@ where
    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::powi_scalar(self.primitive, other))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(1e-5);
        let atol = atol.unwrap_or(1e-8);

        K::lower_equal(
            K::abs(K::sub(self.primitive, other.primitive.clone())),
            K::add_scalar(K::mul_scalar(K::abs(other.primitive), rtol), atol),
        )
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Remarks
    ///
    /// This method is only available for non-wasm targets or when the `wasm-sync` feature is enabled.
    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol).all().into_scalar()
    }
}

impl<B, K> Tensor<B, 2, K>
@ -42,6 +42,7 @@ macro_rules! testgen_all {
    burn_tensor::testgen_cat!();
    burn_tensor::testgen_chunk!();
    burn_tensor::testgen_clamp!();
    burn_tensor::testgen_close!();
    burn_tensor::testgen_cos!();
    burn_tensor::testgen_create_like!();
    burn_tensor::testgen_div!();
@ -0,0 +1,24 @@
#[burn_tensor_testgen::testgen(close)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn test_is_close() {
        let tensor1 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
        let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;
        let data_actual = tensor1.is_close(tensor2, None, None).into_data();
        let data_expected = Data::from([[true, true, true], [true, true, false]]);
        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn test_all_close() {
        let tensor1 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
        let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;
        assert!(!tensor1.clone().all_close(tensor2.clone(), None, None));

        let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-9;
        assert!(tensor1.all_close(tensor2, None, None));
    }
}
@ -10,6 +10,7 @@ mod cast;
mod cat;
mod chunk;
mod clamp;
mod close;
mod cos;
mod create_like;
mod div;
@ -249,6 +249,7 @@ impl CompilationSettings {
    }
}

#[allow(dead_code)]
fn is_contiguous(strides: &[usize]) -> bool {
    let mut current = 0;

@ -271,6 +272,7 @@ pub enum InputInfo {

impl InputInfo {
    /// The item type of the input.
    #[allow(dead_code)]
    pub fn item(&self) -> Item {
        match self {
            InputInfo::Array {

@ -284,6 +286,7 @@ impl InputInfo {

impl OutputInfo {
    /// The item type of the output.
    #[allow(dead_code)]
    pub fn item(&self) -> Item {
        match self {
            OutputInfo::ArrayWrite { item, local: _ } => *item,

@ -314,6 +317,7 @@ pub enum OutputInfo {
}

impl OutputInfo {
    #[allow(dead_code)]
    pub fn elem_size<R: Runtime>(&self) -> usize {
        let elem = match self {
            OutputInfo::ArrayWrite { item, local: _ } => bool_elem(item.elem()),

@ -170,6 +170,7 @@ impl Scope {
        }
    }

    #[allow(dead_code)]
    pub(crate) fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> {
        self.reads_global
            .iter()