
README.md


Burn aims to be a complete, highly flexible deep learning framework written in Rust. The goal is to serve researchers and practitioners alike by making it easier to experiment with, train, and deploy models.

Disclaimer: Burn is currently in active development, and there will be breaking changes. While any resulting issues are likely to be easy to fix, there are no guarantees at this stage.


Features

  • Flexible and intuitive custom neural network module 🔥
  • Training with full support for metrics, logging and checkpointing 📈
  • Tensor crate with backends as plugins 🔧
    • Tch backend with CPU/GPU support 🚀
    • NdArray backend with no_std support, running on any platform 👌
    • Autodiff backend making any backend differentiable 🌟
  • Dataset crate with multiple utilities and sources 📚

Get Started

The best way to get started with burn is to clone the repo and play with the examples. It may also be a good idea to look at the main components of burn for a quick overview of its fundamental building blocks.

Examples

Components

Understanding the key components and philosophy of burn can greatly help when beginning to work with the framework.

Backend

Nearly everything in burn is based on the Backend trait, which enables you to run tensor operations using different implementations without having to modify your code. While a backend may not necessarily have autodiff capabilities, the ADBackend trait specifies when autodiff is needed. This trait not only abstracts operations but also tensor, device and element types, providing each backend the flexibility they need. It's worth noting that the trait assumes eager mode since burn fully supports dynamic graphs. However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code.
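To make the idea of a backend trait concrete, here is a minimal, std-only sketch. It is not burn's actual Backend trait: ToyBackend, CpuF32, CpuF64 and double are hypothetical names chosen for illustration, and the real trait covers far more (devices, many operations, several element types). The sketch only shows how associated types let each backend pick its own element and tensor representation while user code is written once.

```rust
// Toy illustration only: `ToyBackend`, `CpuF32` and `CpuF64` are hypothetical
// names and are not part of burn's API.
trait ToyBackend {
    /// Element type chosen by the backend.
    type Elem: Copy + std::fmt::Debug + PartialEq;
    /// Tensor storage chosen by the backend.
    type Tensor: Clone;

    fn from_vec(data: Vec<f64>) -> Self::Tensor;
    fn add(lhs: Self::Tensor, rhs: Self::Tensor) -> Self::Tensor;
    fn to_vec(tensor: Self::Tensor) -> Vec<Self::Elem>;
}

struct CpuF32;
struct CpuF64;

impl ToyBackend for CpuF32 {
    type Elem = f32;
    type Tensor = Vec<f32>;

    fn from_vec(data: Vec<f64>) -> Self::Tensor {
        data.into_iter().map(|v| v as f32).collect()
    }
    fn add(lhs: Self::Tensor, rhs: Self::Tensor) -> Self::Tensor {
        lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
    }
    fn to_vec(tensor: Self::Tensor) -> Vec<Self::Elem> {
        tensor
    }
}

impl ToyBackend for CpuF64 {
    type Elem = f64;
    type Tensor = Vec<f64>;

    fn from_vec(data: Vec<f64>) -> Self::Tensor {
        data
    }
    fn add(lhs: Self::Tensor, rhs: Self::Tensor) -> Self::Tensor {
        lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
    }
    fn to_vec(tensor: Self::Tensor) -> Vec<Self::Elem> {
        tensor
    }
}

/// Written once against the trait; runs unchanged on either backend.
fn double<B: ToyBackend>(data: Vec<f64>) -> Vec<B::Elem> {
    let tensor = B::from_vec(data);
    B::to_vec(B::add(tensor.clone(), tensor))
}

fn main() {
    println!("{:?}", double::<CpuF32>(vec![1.0, 2.5]));
    println!("{:?}", double::<CpuF64>(vec![1.0, 2.5]));
}
```

Switching backends only changes the type argument at the call site; the body of double stays the same.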

Tensor

At the core of burn lies the Tensor struct, which encompasses multiple types of tensors, including Float, Int, and Bool. The element types of these tensors are specified by the backend and are usually designated as a generic argument (e.g., NdArrayBackend<f32>). Although the same struct is used for all tensors, the available methods differ depending on the tensor kind. You can specify the desired tensor kind by setting the third generic argument, which defaults to Float. The first generic argument specifies the backend, while the second specifies the number of dimensions.

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, Int};

fn function<B: Backend>(tensor_float: Tensor<B, 2>) {
    let _tensor_bool = tensor_float.clone().equal_elem(2.0); // Tensor<B, 2, Bool>
    let _tensor_int = tensor_float.argmax(1); // Tensor<B, 2, Int>
}

As the previous example shows, nearly all operations take owned tensors as parameters, so you must call clone explicitly when reusing the same tensor multiple times. This is cheap: the tensor's data is not copied. Instead, it is flagged as read-only when multiple tensors share the same allocated memory. This lets backends reuse tensor data when possible, similar to a copy-on-write pattern, while remaining completely transparent to the user.
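The copy-on-write idea can be sketched with the standard library alone. The following is an assumption-laden illustration, not how burn's backends are implemented: SharedTensor and its methods are hypothetical, and reference counting via Rc merely stands in for whatever sharing mechanism a backend uses. It shows that cloning a handle is cheap, and that an owning operation copies data only when another handle still shares the buffer.

```rust
use std::rc::Rc;

// Hypothetical toy type; burn's real tensors are backend specific.
#[derive(Clone)]
struct SharedTensor {
    data: Rc<Vec<f32>>,
}

impl SharedTensor {
    fn new(data: Vec<f32>) -> Self {
        Self { data: Rc::new(data) }
    }

    /// Takes ownership of the tensor. `Rc::make_mut` mutates the buffer in
    /// place when this handle is the only owner, and clones the data first
    /// when another handle still shares it.
    fn add_scalar(mut self, value: f32) -> Self {
        for v in Rc::make_mut(&mut self.data).iter_mut() {
            *v += value;
        }
        self
    }

    /// True when at least one other handle points to the same buffer.
    fn is_shared(&self) -> bool {
        Rc::strong_count(&self.data) > 1
    }

    fn to_vec(&self) -> Vec<f32> {
        self.data.as_ref().clone()
    }
}

fn main() {
    let a = SharedTensor::new(vec![1.0, 2.0]);
    let b = a.clone(); // Cheap: only the reference count changes.
    println!("shared after clone: {}", a.is_shared());

    let c = b.add_scalar(1.0); // `a` still shares the buffer, so data is copied.
    println!("a = {:?}, c = {:?}", a.to_vec(), c.to_vec());

    let d = c.add_scalar(1.0); // `c` was the sole owner, so the buffer is reused.
    println!("d = {:?}", d.to_vec());
}
```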

Autodiff

The Backend trait is highly flexible, enabling backpropagation to be implemented using a simple backend decorator, which makes any backend differentiable.

use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;

fn linear<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>, bias: Tensor<B, 2>) -> Tensor<B, 2> {
    x.matmul(weight) + bias
}

fn main() {
    type Backend = NdArrayBackend<f32>;

    let weight = Tensor::random([3, 3], Distribution::Standard);
    let bias = Tensor::zeros([1, 3]);
    let x = Tensor::random([3, 3], Distribution::Standard);

    let y = linear::<Backend>(x.clone(), weight.clone(), bias.clone());
    // y.backward() // Method backward doesn't exist

    let y = linear::<ADBackendDecorator<Backend>>(
        Tensor::from_inner(x),
        Tensor::from_inner(weight).require_grad(),
        Tensor::from_inner(bias).require_grad(),
    );
    let grads = y.backward(); // Method exists
}
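The decorator pattern used above can be illustrated with a small std-only sketch. Note the swapped-in technique: for brevity this toy uses forward-mode differentiation with dual numbers, whereas the text above describes backpropagation (reverse mode). All names (ToyBackend, Scalar, Diff, Dual, grad_wrt_weight) are hypothetical and unrelated to burn's API. The point is only that a wrapper type can implement the same trait as the backend it wraps and add derivative tracking on top.

```rust
use std::marker::PhantomData;

// All names here are hypothetical; this is not burn's API.
trait ToyBackend {
    type Tensor: Clone;
    fn from_f64(value: f64) -> Self::Tensor;
    fn add(a: Self::Tensor, b: Self::Tensor) -> Self::Tensor;
    fn mul(a: Self::Tensor, b: Self::Tensor) -> Self::Tensor;
}

/// Base backend: plain scalars with no notion of gradients.
struct Scalar;

impl ToyBackend for Scalar {
    type Tensor = f64;
    fn from_f64(value: f64) -> f64 { value }
    fn add(a: f64, b: f64) -> f64 { a + b }
    fn mul(a: f64, b: f64) -> f64 { a * b }
}

/// Decorator: wraps any backend and carries a derivative next to each value.
struct Diff<B>(PhantomData<B>);

#[derive(Clone)]
struct Dual<T> {
    value: T,
    grad: T,
}

impl<B: ToyBackend> ToyBackend for Diff<B> {
    type Tensor = Dual<B::Tensor>;

    fn from_f64(value: f64) -> Self::Tensor {
        Dual { value: B::from_f64(value), grad: B::from_f64(0.0) }
    }
    fn add(a: Self::Tensor, b: Self::Tensor) -> Self::Tensor {
        Dual { value: B::add(a.value, b.value), grad: B::add(a.grad, b.grad) }
    }
    fn mul(a: Self::Tensor, b: Self::Tensor) -> Self::Tensor {
        // Product rule: (a * b)' = a' * b + a * b'
        Dual {
            value: B::mul(a.value.clone(), b.value.clone()),
            grad: B::add(B::mul(a.grad, b.value), B::mul(a.value, b.grad)),
        }
    }
}

/// Model code written once against the trait: y = x * w + b.
fn linear<B: ToyBackend>(x: B::Tensor, w: B::Tensor, b: B::Tensor) -> B::Tensor {
    B::add(B::mul(x, w), b)
}

/// Returns (y, dy/dw) by running `linear` on the decorated backend.
fn grad_wrt_weight(x: f64, w: f64, b: f64) -> (f64, f64) {
    type AD = Diff<Scalar>;
    let x = AD::from_f64(x);
    let mut w = AD::from_f64(w);
    w.grad = 1.0; // Seed: differentiate with respect to `w`.
    let b = AD::from_f64(b);
    let y = linear::<AD>(x, w, b);
    (y.value, y.grad)
}

fn main() {
    println!("{}", linear::<Scalar>(2.0, 3.0, 1.0));
    println!("{:?}", grad_wrt_weight(2.0, 3.0, 1.0));
}
```

The same linear function runs on the plain backend and on the decorated one; only the decorated call yields a derivative.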

Module

The Module derive allows you to create your own neural network modules, similar to PyTorch. Note that the Module derive generates all the necessary methods to make your type essentially a parameter container. It makes no assumptions about how the forward function is declared.

use burn::module::{Module, Param};
use burn::nn::{Dropout, Linear, GELU};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Param<Linear<B>>,
    linear_outer: Param<Linear<B>>,
    dropout: Dropout,
    gelu: GELU,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}

Note that only the fields wrapped inside Param are updated during training, and the other fields should implement the Clone trait.
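As a rough mental model of "parameter container", the hand-written sketch below shows one way a type could expose its trainable fields to an optimizer while leaving other fields alone. It is an assumption for illustration only: ToyParam, ToyModule, FeedForward and step are hypothetical, and the methods that burn's Module derive actually generates are not shown here and may look quite different.

```rust
/// Hypothetical stand-in for a trainable parameter wrapper.
struct ToyParam(Vec<f32>);

/// Hypothetical trait sketching the kind of methods a derive could generate.
trait ToyModule {
    fn num_params(&self) -> usize;
    fn visit_params_mut(&mut self, f: &mut dyn FnMut(&mut ToyParam));
}

struct FeedForward {
    weight_inner: ToyParam,
    weight_outer: ToyParam,
    dropout_prob: f64, // Plain field: never visited, so never updated.
}

impl ToyModule for FeedForward {
    fn num_params(&self) -> usize {
        self.weight_inner.0.len() + self.weight_outer.0.len()
    }

    fn visit_params_mut(&mut self, f: &mut dyn FnMut(&mut ToyParam)) {
        // Only the wrapped fields are handed to the visitor.
        f(&mut self.weight_inner);
        f(&mut self.weight_outer);
    }
}

/// Toy "optimizer step": subtracts `lr` from every parameter value.
fn step<M: ToyModule>(module: &mut M, lr: f32) {
    module.visit_params_mut(&mut |param: &mut ToyParam| {
        for v in param.0.iter_mut() {
            *v -= lr;
        }
    });
}

fn main() {
    let mut ff = FeedForward {
        weight_inner: ToyParam(vec![1.0, 2.0]),
        weight_outer: ToyParam(vec![3.0]),
        dropout_prob: 0.1,
    };
    step(&mut ff, 0.5);
    println!("{} params, inner = {:?}", ff.num_params(), ff.weight_inner.0);
    println!("dropout unchanged: {}", ff.dropout_prob);
}
```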

Config

The Config derive lets you define serializable and deserializable configurations or hyper-parameters for your modules or any components.

use burn::config::Config;

#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
    #[config(default = 0.1)]
    pub dropout: f64,
}

The derive also adds useful methods to your config, similar to a builder pattern.

fn main() {
    let config = PositionWiseFeedForwardConfig::new(512, 2048);
    println!("{}", config.d_model); // 512
    println!("{}", config.d_ff); // 2048
    println!("{}", config.dropout); // 0.1
    let config = config.with_dropout(0.2);
    println!("{}", config.dropout); // 0.2
}
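For readers curious what such builder-style methods amount to, here is a hand-written, std-only approximation of the behavior shown above. It is a sketch under assumptions: FeedForwardConfig is a hypothetical name, the code generated by the Config derive may differ, and serialization is left out entirely. Required fields become constructor arguments, while fields with a default get a with_ setter.

```rust
// Hand-written approximation; the code generated by burn's derive may differ.
#[derive(Debug, Clone, PartialEq)]
pub struct FeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
    pub dropout: f64,
}

impl FeedForwardConfig {
    /// Fields without a default are required; `dropout` starts at 0.1.
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        Self { d_model, d_ff, dropout: 0.1 }
    }

    /// Builder-style setter: consumes the config and returns the updated one.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}

fn main() {
    let config = FeedForwardConfig::new(512, 2048).with_dropout(0.2);
    println!("{:?}", config);
}
```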

Learner

The Learner is the main struct that lets you train a neural network with support for logging, metrics, checkpointing and more. To create a learner, you must use the LearnerBuilder.

use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};

fn main() {
    let dataloader_train = ...;
    let dataloader_valid = ...;

    let model = ...;
    let optim = ...;

    let learner = LearnerBuilder::new("/tmp/artifact_dir")
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
        .with_file_checkpointer::<f32>(2)
        .num_epochs(10)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_valid);
}

See this example for real-world usage.

no_std support

Burn supports no_std with alloc for inference with the NdArray backend. Simply disable the default features of the burn and burn-ndarray crates (the minimum required to run inference). See the burn-no-std-tests example as a reference implementation.
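Disabling default features is done in your own Cargo.toml. The fragment below is a sketch: the version numbers are an assumption and should match the release you actually depend on, and no additional feature flags are shown because none are named here.

```toml
[dependencies]
# Version numbers are assumed for illustration.
burn = { version = "0.6.0", default-features = false }
burn-ndarray = { version = "0.6.0", default-features = false }
```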

Additionally, the burn-core and burn-tensor crates support no_std with alloc if you need to include them directly as dependencies (the burn crate re-exports burn-core and burn-tensor). Note that in no_std mode, a random seed is generated at build time if the seed is not initialized with the Backend::seed method. Also, spin::mutex::Mutex is used in place of std::sync::Mutex in no_std mode.

License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.