minor fixes

This commit is contained in:
Aasheesh Singh 2024-01-30 10:23:33 -05:00
parent 61dd2c17b9
commit 4ebab1a538
2 changed files with 4 additions and 4 deletions

View File

@ -286,7 +286,7 @@ used in this specific example, it is possible to add the kind of the tensor as a
argument. For example, a 3-dimensional Tensor of different data types(float, int, bool) would be defined as following: argument. For example, a 3-dimensional Tensor of different data types(float, int, bool) would be defined as following:
```rust , ignore ```rust , ignore
Tensor<B, 3 > // Float tensor (default) Tensor<B, 3> // Float tensor (default)
Tensor<B, 3, Float> // Float tensor (explicit) Tensor<B, 3, Float> // Float tensor (explicit)
Tensor<B, 3, Int> // Int tensor Tensor<B, 3, Int> // Int tensor
Tensor<B, 3, Bool> // Bool tensor Tensor<B, 3, Bool> // Bool tensor

View File

@ -14,7 +14,7 @@ Note that the specific element types used for `Float`, `Int`, and `Bool` tensors
backend implementations. backend implementations.
Burn Tensors are defined by the number of dimensions D in its declaration as opposed to its shape. The Burn Tensors are defined by the number of dimensions D in its declaration as opposed to its shape. The
actual shape of the tensor is inferred from its initialization. For example, a Tensor of size: (5,) is initialized as actual shape of the tensor is inferred from its initialization. For example, a Tensor of size (5,) is initialized as
below: below:
```rust, ignore ```rust, ignore
@ -34,8 +34,8 @@ Let's look at a couple of examples for initializing a tensor from different inpu
```rust, ignore ```rust, ignore
// Initialization from a given Backend (WGpu) // Initialization from a given Backend (Wgpu)
let tensor_1 = Tensor::<WGpu, 1>::from_data([1.0, 2.0, 3.0]); let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0]);
// Initialization from a generic Backend // Initialization from a generic Backend
let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert()); let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert());