mirror of https://github.com/tracel-ai/burn.git
Merge pull request #1183 from tracel-ai/docs/guide
Docs/guide Improve guide book: Tensor, Dataset and Example Section.
This commit is contained in:
commit
8b4038d004
|
@ -4,7 +4,14 @@ This guide will walk you through the process of creating a custom model built wi
|
|||
train a simple convolutional neural network model on the MNIST dataset and prepare it for inference.
|
||||
|
||||
For clarity, we sometimes omit imports in our code snippets. For more details, please refer to the
|
||||
corresponding code in the `examples/guide` directory.
|
||||
corresponding code in the `examples/guide` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide). We
|
||||
reproduce this example in a step-by-step fashion, from dataset creation to modeling and training in the following
|
||||
sections.
|
||||
The code for this demo can be executed from Burn's base directory using the command:
|
||||
|
||||
```bash
|
||||
cargo run --example guide
|
||||
```
|
||||
|
||||
## Key Learnings
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
We have effectively written most of the necessary code to train our model. However, we have not
|
||||
explicitly designated the backend to be used at any point. This will be defined in the main
|
||||
entrypoint of our program, namely the `main` function below.
|
||||
entrypoint of our program, namely the `main` function defined in `src/main.rs`.
|
||||
|
||||
```rust , ignore
|
||||
use burn::optim::AdamConfig;
|
||||
|
@ -28,7 +28,9 @@ fn main() {
|
|||
You might be wondering why we use the `guide` prefix to bring the different modules we just
|
||||
implemented into scope. Instead of including the code in the current guide in a single file, we
|
||||
separated it into different files which group related code into _modules_. The `guide` is simply the
|
||||
name we gave to our _crate_, which contains the different files. Below is a brief explanation of the
|
||||
name we gave to our _crate_, which contains the different files. If you named your project crate
|
||||
as `my-first-burn-model`,
|
||||
you can equivalently replace all usages of `guide` above with `my-first-burn-model`. Below is a brief explanation of the
|
||||
different parts of the Rust module system.
|
||||
|
||||
A **package** is a bundle of one or more crates that provides a set of functionality. A package
|
||||
|
@ -38,13 +40,23 @@ A **crate** is a compilation unit in Rust. It could be a single file, but it is
|
|||
split up crates into multiple _modules_ and possibly multiple files. A crate can come in one of two
|
||||
forms: a binary crate or a library crate. When compiling a crate, the compiler first looks in the
|
||||
crate root file (usually `src/lib.rs` for a library crate or `src/main.rs` for a binary crate). Any
|
||||
module declared in the crate root file will be inserted in the crate for compilation.
|
||||
module declared in the crate root file will be inserted in the crate for compilation. For this demo example, we will
|
||||
define a library crate where all the individual modules (model, data, training, etc.) are listed inside `src/lib.rs` as
|
||||
follows:
|
||||
|
||||
```
|
||||
pub mod data;
|
||||
pub mod inference;
|
||||
pub mod model;
|
||||
pub mod training;
|
||||
```
|
||||
|
||||
A **module** lets us organize code within a crate for readability and easy reuse. Modules also allow
|
||||
us to control the _privacy_ of items.
|
||||
us to control the _privacy_ of items. The `pub` keyword used above, for example, is employed to make a module publicly
|
||||
available inside the crate.
|
||||
|
||||
For this guide, we defined a library crate with a single example where the `main` function is
|
||||
defined, as illustrated in the structure below.
|
||||
The entry point of our program is the `main` function, defined in the `examples/guide.rs` file. The file structure
|
||||
for this example is illustrated below:
|
||||
|
||||
```
|
||||
guide
|
||||
|
@ -60,7 +72,8 @@ guide
|
|||
```
|
||||
|
||||
The source for this guide can be found in our
|
||||
[GitHub repository](https://github.com/tracel-ai/burn/tree/main/examples/guide).\
|
||||
[GitHub repository](https://github.com/tracel-ai/burn/tree/main/examples/guide) which can be used to run this basic
|
||||
workflow example end-to-end.\
|
||||
|
||||
</details><br>
|
||||
|
||||
|
|
|
@ -9,6 +9,11 @@ To iterate over a dataset efficiently, we will define a struct which will implem
|
|||
trait. The goal of a batcher is to map individual dataset items into a batched tensor that can be
|
||||
used as input to our previously defined model.
|
||||
|
||||
Let us start by defining our dataset functionalities in a file `src/data.rs`. We shall omit some of the imports for
|
||||
brevity,
|
||||
but the full code for following this guide can be found
|
||||
at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide).
|
||||
|
||||
```rust , ignore
|
||||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
|
||||
|
|
|
@ -6,7 +6,7 @@ For loading a model primed for inference, it is of course more efficient to dire
|
|||
weights into the model, bypassing the need to initially set arbitrary weights or worse, weights
|
||||
computed from a Xavier normal initialization only to then promptly replace them with the stored
|
||||
weights. With that in mind, let's create a new initialization function receiving the record as
|
||||
input.
|
||||
input. This new function can be defined alongside the `init` function for the `ModelConfig` struct in `src/model.rs`.
|
||||
|
||||
```rust , ignore
|
||||
impl ModelConfig {
|
||||
|
@ -30,7 +30,7 @@ It is important to note that the `ModelRecord` was automatically generated thank
|
|||
trait. It allows us to load the module state without having to deal with fetching the correct type
|
||||
manually. Everything is validated when loading the model with the record.
|
||||
|
||||
Now let's create a simple `infer` method in which we will load our trained model.
|
||||
Now let's create a simple `infer` method in a new file `src/inference.rs` which we will use to load our trained model.
|
||||
|
||||
```rust , ignore
|
||||
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
|
||||
|
|
|
@ -19,14 +19,14 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
burn = { version = "0.12.0", features=["train", "wgpu"]}
|
||||
burn = { version = "0.12.0", features = ["train", "wgpu"] }
|
||||
```
|
||||
|
||||
Our goal will be to create a basic convolutional neural network used for image classification. We
|
||||
will keep the model simple by using two convolution layers followed by two linear layers, some
|
||||
pooling and ReLU activations. We will also use dropout to improve training performance.
|
||||
|
||||
Let us start by creating a model in a file `model.rs`.
|
||||
Let us start by defining our model struct in a new file `src/model.rs`.
|
||||
|
||||
```rust , ignore
|
||||
use burn::{
|
||||
|
@ -281,9 +281,9 @@ network modules already built with Burn use the `forward` nomenclature, simply b
|
|||
standard in the field.
|
||||
|
||||
Similar to neural network modules, the [`Tensor`](../building-blocks/tensor.md) struct given as a
|
||||
parameter also takes the Backend trait as a generic argument, alongside its rank. Even if it is not
|
||||
parameter also takes the Backend trait as a generic argument, alongside its dimensionality. Even if it is not
|
||||
used in this specific example, it is possible to add the kind of the tensor as a third generic
|
||||
argument.
|
||||
argument. For example, a 3-dimensional Tensor of different data types(float, int, bool) would be defined as following:
|
||||
|
||||
```rust , ignore
|
||||
Tensor<B, 3> // Float tensor (default)
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
# Training
|
||||
|
||||
We are now ready to write the necessary code to train our model on the MNIST dataset. Instead of a
|
||||
simple tensor, the model should output an item that can be understood by the learner, a struct whose
|
||||
We are now ready to write the necessary code to train our model on the MNIST dataset.
|
||||
We shall define the code for this training section in the file: `src/training.rs`.
|
||||
|
||||
Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose
|
||||
responsibility is to apply an optimizer to the model. The output struct is used for all metrics
|
||||
calculated during the training. Therefore it should include all the necessary information to
|
||||
calculate any metric that you want for a task.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Dataset
|
||||
|
||||
Most deep learning training being done on datasets –with perhaps the exception of reinforcement learning–, it is essential to provide a convenient and performant API.
|
||||
Most deep learning training being done on datasets –with perhaps the exception of reinforcement learning–, it is
|
||||
essential to provide a convenient and performant API.
|
||||
The dataset trait is quite similar to the dataset abstract class in PyTorch:
|
||||
|
||||
```rust, ignore
|
||||
|
@ -25,20 +26,72 @@ transformations is to provide you with the necessary tools so that you can model
|
|||
distributions.
|
||||
|
||||
| Transformation | Description |
|
||||
| ----------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
||||
|-------------------|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| `SamplerDataset` | Samples items from a dataset. This is a convenient way to model a dataset as a probability distribution of a fixed size. |
|
||||
| `ShuffledDataset` | Maps each input index to a random index, similar to a dataset sampled without replacement. |
|
||||
| `PartialDataset` | Returns a view of the input dataset with a specified range. |
|
||||
| `MapperDataset` | Computes a transformation lazily on the input dataset. |
|
||||
| `ComposedDataset` | Composes multiple datasets together to create a larger one without copying any data. |
|
||||
|
||||
Let us look at the basic usages of each dataset transform and how they can be composed together. These transforms
|
||||
are lazy by default except when specified, reducing the need for unnecessary intermediate allocations and improving
|
||||
performance. The full documentation of each transform can be found at
|
||||
the [API reference](https://burn.dev/docs/burn/data/dataset/transform/index.html).
|
||||
|
||||
* **SamplerDataset**: This transform can be used to sample items from a dataset with (default) or without replacement.
|
||||
Transform is initialized with a sampling size which can be bigger or smaller than the input dataset size. This is
|
||||
particularly useful in cases where we want to checkpoint larger datasets more often during training
|
||||
and smaller datasets less often as the size of an epoch is now controlled by the sampling size. Sample usage:
|
||||
|
||||
```rust, ignore
|
||||
type DbPedia = SqliteDataset<DbPediaItem>;
|
||||
let dataset: DbPedia = HuggingfaceDatasetLoader::new("dbpedia_14")
|
||||
.dataset("train").
|
||||
.unwrap();
|
||||
|
||||
let dataset = SamplerDataset<DbPedia, DbPediaItem>::new(dataset, 10000);
|
||||
```
|
||||
|
||||
* **ShuffledDataset**: This transform can be used to shuffle the items of a dataset. Particularly useful before
|
||||
splitting
|
||||
the raw dataset into train/test splits. Can be initialized with a seed to ensure reproducibility.
|
||||
|
||||
```rust, ignore
|
||||
let dataset = ShuffledDataset<DbPedia, DbPediaItem>::with_seed(dataset, 42);
|
||||
```
|
||||
|
||||
* **PartialDataset**: This transform is useful to return a view of the dataset with specified start and end indices.
|
||||
Used
|
||||
to create train/val/test splits. In the example below, we show how to chain ShuffledDataset and PartialDataset to
|
||||
create
|
||||
splits.
|
||||
|
||||
```rust, ignore
|
||||
// define chained dataset type here for brevity
|
||||
type PartialData = PartialDataset<ShuffledDataset<DbPedia, DbPediaItem>>;
|
||||
let dataset_len = dataset.len();
|
||||
let split == "train"; // or "val"/"test"
|
||||
|
||||
let data_split = match split {
|
||||
"train" => PartialData::new(dataset, 0, len * 8 / 10), // Get first 80% dataset
|
||||
"test" => PartialData::new(dataset, len * 8 / 10, len), // Take remaining 20%
|
||||
_ => panic!("Invalid split type"), // Handle unexpected split types
|
||||
};
|
||||
```
|
||||
|
||||
* **MapperDataset**: This transform is useful to apply a transformation on each of the items of a dataset. Particularly
|
||||
useful for normalization of image data when channel means are known.
|
||||
|
||||
* **ComposedDataset**: This transform is useful to compose multiple datasets downloaded from multiple sources (say
|
||||
different HuggingfaceDatasetLoader sources) into a single bigger dataset which can be sampled from one source.
|
||||
|
||||
## Storage
|
||||
|
||||
There are multiple dataset storage options available for you to choose from. The choice of the
|
||||
dataset to use should be based on the dataset's size as well as its intended purpose.
|
||||
|
||||
| Storage | Description |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
|-----------------|---------------------------------------------------------------------------------------------------------------------------|
|
||||
| `InMemDataset` | In-memory dataset that uses a vector to store items. Well-suited for smaller datasets. |
|
||||
| `SqliteDataset` | Dataset that uses SQLite to index items that can be saved in a simple SQL database file. Well-suited for larger datasets. |
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Tensor
|
||||
|
||||
As previously explained in the [model section](../basic-workflow/model.md), the Tensor struct has 3
|
||||
generic arguments: the backend, the number of dimensions (rank) and the data type.
|
||||
generic arguments: the backend B, the dimensionality D, and the data type.
|
||||
|
||||
```rust , ignore
|
||||
Tensor<B, D> // Float tensor (default)
|
||||
|
@ -13,14 +13,104 @@ Tensor<B, D, Bool> // Bool tensor
|
|||
Note that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by
|
||||
backend implementations.
|
||||
|
||||
## Operations
|
||||
Burn Tensors are defined by the number of dimensions D in its declaration as opposed to its shape. The
|
||||
actual shape of the tensor is inferred from its initialization. For example, a Tensor of size (5,) is initialized as
|
||||
below:
|
||||
|
||||
```rust, ignore
|
||||
// correct: Tensor is 1-Dimensional with 5 elements
|
||||
let tensor_1 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
|
||||
// incorrect: let tensor_1 = Tensor::<Backend, 5>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
// this will lead to an error and is for creating a 5-D tensor
|
||||
```
|
||||
|
||||
### Initialization
|
||||
|
||||
Burn Tensors are primarily initialized using the `from_data()` method which takes the `Data` struct as input.
|
||||
The `Data` struct has two fields: value & shape. To retrieve the data from a tensor, the method `.to_data()` should be
|
||||
employed when intending to reuse the tensor afterward. Alternatively, `.into_data()` is recommended for one-time use.
|
||||
Let's look at a couple of examples for initializing a tensor from different inputs.
|
||||
|
||||
```rust, ignore
|
||||
|
||||
// Initialization from a given Backend (Wgpu)
|
||||
let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0]);
|
||||
|
||||
// Initialization from a generic Backend
|
||||
let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert());
|
||||
|
||||
// Initialization using from_floats (Recommended for f32 ElementType)
|
||||
// Will be converted to Data internally. `.convert()` not needed as from_floats() defined for fixed ElementType
|
||||
let tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0]);
|
||||
|
||||
// Initalization of Int Tensor from array slices
|
||||
let arr: [i32; 6] = [1, 2, 3, 4, 5, 6];
|
||||
let tensor_4 = Tensor::<Backend, 1, Int>::from_data(Data::from(&arr[0..3]).convert());
|
||||
|
||||
// Initialization from a custom type
|
||||
|
||||
struct BodyMetrics {
|
||||
age: i8,
|
||||
height: i16,
|
||||
weight: f32
|
||||
}
|
||||
|
||||
let bmi = BodyMetrics{
|
||||
age: 25,
|
||||
height: 180,
|
||||
weight: 80.0
|
||||
};
|
||||
let tensor_5 = Tensor::<Backend, 1>::from_data(Data::from([bmi.age as f32, bmi.height as f32, bmi.weight]).convert());
|
||||
|
||||
```
|
||||
|
||||
The `.convert()` method for Data struct is called to ensure that the data's primitive type is
|
||||
consistent across all backends. With `.from_floats()` method the ElementType is fixed as f32
|
||||
and therefore no convert operation is required across backends. This operation can also be done at element wise
|
||||
level as:
|
||||
`let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([(item.age as i64).elem()])`. The `ElementConversion` trait
|
||||
however needs to be imported for the element wise operation.
|
||||
|
||||
## Ownership and Cloning
|
||||
|
||||
Almost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple
|
||||
times will necessitate cloning it. Don't worry, the tensor's buffer isn't copied, but a reference to
|
||||
it is increased. This makes it possible to determine exactly how many times a tensor is used, which
|
||||
is very convenient for reusing tensor buffers and improving performance. For that reason, we don't
|
||||
provide explicit inplace operations. If a tensor is used only one time, inplace operations will
|
||||
always be used when available.
|
||||
times will necessitate cloning it. Let's look at an example to understand the ownership rules and cloning better.
|
||||
Suppose we want to do a simple min-max normalization of an input tensor.
|
||||
|
||||
```rust, ignore
|
||||
let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0]);
|
||||
let min = input.min();
|
||||
let max = input.max();
|
||||
let input = (input - min).div(max - min);
|
||||
```
|
||||
|
||||
With PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules will give an error
|
||||
and prevent using the input tensor after the first `.min()` operation. The ownership of the input tensor is transferred
|
||||
to the variable `min` and the input tensor is no longer available for further operations. Burn Tensors like most
|
||||
complex primitives do not implement the `Copy` trait and therefore have to be cloned explicitly. Now let's rewrite
|
||||
a working example of doing min-max normalization with cloning.
|
||||
|
||||
```rust, ignore
|
||||
let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0]);
|
||||
let min = input.clone().min();
|
||||
let max = input.clone().max();
|
||||
let input = (input.clone() - min.clone()).div(max - min);
|
||||
println!("{:?}", input.to_data()); // Success: [0.0, 0.33333334, 0.6666667, 1.0]
|
||||
|
||||
// Notice that max, min have been moved in last operation so the below print will give an error.
|
||||
// If we want to use them for further operations, they will need to be cloned in similar fashion.
|
||||
// println!("{:?}", min.to_data());
|
||||
```
|
||||
|
||||
We don't need to be worried about memory overhead because with cloning, the tensor's buffer isn't copied,
|
||||
and only a reference to it is increased. This makes it possible to determine exactly how many times a tensor is used,
|
||||
which is very convenient for reusing tensor buffers or even fusing operations into a single
|
||||
kernel ([burn-fusion](https://burn.dev/docs/burn_fusion/index.htmls)).
|
||||
For that reason, we don't provide explicit inplace operations. If a tensor is used only one time, inplace operations
|
||||
will always be used when available.
|
||||
|
||||
## Tensor Operations
|
||||
|
||||
Normally with PyTorch, explicit inplace operations aren't supported during the backward pass, making
|
||||
them useful only for data preprocessing or inference-only model implementations. With Burn, you can
|
||||
|
@ -37,7 +127,7 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
|
|||
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
||||
|
||||
| Burn | PyTorch Equivalent |
|
||||
| ------------------------------------- | ------------------------------------ |
|
||||
|---------------------------------------|--------------------------------------|
|
||||
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
|
||||
| `tensor.dims()` | `tensor.size()` |
|
||||
| `tensor.shape()` | `tensor.shape` |
|
||||
|
@ -67,7 +157,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
||||
|
||||
| Burn | PyTorch Equivalent |
|
||||
| ---------------------------------------------------------------- | ---------------------------------------------- |
|
||||
|------------------------------------------------------------------|------------------------------------------------|
|
||||
| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
|
||||
| `tensor + other` or `tensor.add(other)` | `tensor + other` |
|
||||
| `tensor + scalar` or `tensor.add_scalar(scalar)` | `tensor + scalar` |
|
||||
|
@ -123,7 +213,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
|||
Those operations are only available for `Float` tensors.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| -------------------------------------------- | ---------------------------------- |
|
||||
|----------------------------------------------|------------------------------------|
|
||||
| `tensor.exp()` | `tensor.exp()` |
|
||||
| `tensor.log()` | `tensor.log()` |
|
||||
| `tensor.log1p()` | `tensor.log1p()` |
|
||||
|
@ -155,7 +245,7 @@ Those operations are only available for `Float` tensors.
|
|||
Those operations are only available for `Int` tensors.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| -------------------------------------- | ------------------------------------------------------- |
|
||||
|----------------------------------------|---------------------------------------------------------|
|
||||
| `tensor.from_ints(ints)` | N/A |
|
||||
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
|
||||
| `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` |
|
||||
|
@ -166,7 +256,7 @@ Those operations are only available for `Int` tensors.
|
|||
Those operations are only available for `Bool` tensors.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| ---------------- | ----------------------------------- |
|
||||
|------------------|-------------------------------------|
|
||||
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
|
||||
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
|
||||
| `tensor.not()` | `tensor.logical_not()` |
|
||||
|
@ -174,7 +264,7 @@ Those operations are only available for `Bool` tensors.
|
|||
## Activation Functions
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| ---------------------------------------- | ----------------------------------------------------- |
|
||||
|------------------------------------------|-------------------------------------------------------|
|
||||
| `activation::gelu(tensor)` | Similar to `nn.functional.gelu(tensor)` |
|
||||
| `activation::log_sigmoid(tensor)` | Similar to `nn.functional.log_sigmoid(tensor)` |
|
||||
| `activation::log_softmax(tensor, dim)` | Similar to `nn.functional.log_softmax(tensor, dim)` |
|
||||
|
|
|
@ -28,7 +28,7 @@ libraries/packages your code depends on, and build said libraries.
|
|||
Below is a quick cheat sheet of the main `cargo` commands you might use throughout this guide.
|
||||
|
||||
| Command | Description |
|
||||
| ------------------- | -------------------------------------------------------------------------------------------- |
|
||||
|---------------------|----------------------------------------------------------------------------------------------|
|
||||
| `cargo new` _path_ | Create a new Cargo package in the given directory. |
|
||||
| `cargo add` _crate_ | Add dependencies to the Cargo.toml manifest file. |
|
||||
| `cargo build` | Compile the local package and all of its dependencies (in debug mode, use `-r` for release). |
|
||||
|
@ -165,7 +165,7 @@ Tensor {
|
|||
```
|
||||
|
||||
While the previous example is somewhat trivial, the upcoming
|
||||
[basic workflow section](./basic-workflow/) will walk you through a much more relevant example for
|
||||
basic workflow section will walk you through a much more relevant example for
|
||||
deep learning applications.
|
||||
|
||||
## Running examples
|
||||
|
|
Loading…
Reference in New Issue