Burn

This library aims to be a complete deep learning framework with extreme flexibility, written in Rust. The goal is to serve researchers as well as practitioners, making it easier to experiment with, train, and deploy models.

Disclaimer: Burn is currently in active development, and there will be breaking changes. While any resulting issues are likely to be easy to fix, there are no guarantees at this stage.

Features

  • Flexible and intuitive custom neural network module 🔥
  • Training with full support for metrics, logging and checkpointing 📈
  • Tensor crate with backends as plugins 🔧
    • Tch backend with CPU/GPU support 🚀
    • NdArray backend with fast compile time 👌
    • Autodiff backend making any backend differentiable 🌟
  • Dataset crate with multiple utilities and sources 📚

Get Started

The best way to get started with burn is to clone the repo and play with the examples. It may also be a good idea to take a look at the main components of burn to get a quick overview of the fundamental building blocks.

Examples

Components

Knowing the main components will be of great help when you start playing with burn.

Backend

Almost everything is based on the Backend trait, which allows you to run tensor operations with different implementations without having to change your code. A backend does not necessarily have autodiff capabilities; the ADBackend trait is there to specify when autodiff is required.

Tensor

The Tensor struct is at the core of the burn framework. It takes two generic parameters: the Backend and the number of dimensions D.

Backpropagation is also supported on any backend by making it auto differentiable using a simple decorator.

use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tch::TchBackend;

fn simple_function<B: Backend>() -> Tensor<B, 2> {
    let x = Tensor::<B, 2>::random([3, 3], Distribution::Standard);
    let y = Tensor::<B, 2>::random([3, 3], Distribution::Standard);

    x.matmul(&y)
}

fn simple_function_grads<B: ADBackend>() -> B::Gradients {
    let z = simple_function::<B>();

    z.backward()
}

fn main() {
    let _z = simple_function::<NdArrayBackend<f32>>(); // Compiles
    let _z = simple_function::<TchBackend<f32>>(); // Compiles

    let _grads = simple_function_grads::<NdArrayBackend<f32>>(); // Doesn't compile
    let _grads = simple_function_grads::<TchBackend<f32>>(); // Doesn't compile

    type ADNdArrayBackend = ADBackendDecorator<NdArrayBackend<f32>>;
    type ADTchBackend = ADBackendDecorator<TchBackend<f32>>;

    let _grads = simple_function_grads::<ADNdArrayBackend>(); // Compiles
    let _grads = simple_function_grads::<ADTchBackend>(); // Compiles
}

Module

The Module derive lets you create your own neural network modules, similar to PyTorch.

use burn::nn;
use burn::module::{Param, Module};
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
  my_param: Param<nn::Linear<B>>,
  repeat: usize,
}

Note that only the fields wrapped inside Param are updated during training; the other fields should implement Clone.
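
To give a rough idea of how such a module is used once defined, here is a minimal forward-pass sketch for MyModule. It is only an illustration: it assumes that Param gives access to the wrapped nn::Linear and that Linear::forward takes and returns an owned 2-D tensor, which may differ between burn versions, so refer to the examples for the canonical API.

use burn::tensor::Tensor;

impl<B: Backend> MyModule<B> {
    // Hypothetical forward pass: apply the inner linear layer `repeat` times.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let mut x = input;
        for _ in 0..self.repeat {
            x = self.my_param.forward(x);
        }
        x
    }
}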

Config

The Config derive lets you define serializable and deserializable configurations or hyper-parameters for your modules or any components.

use burn::config::Config;

#[derive(Config)]
struct MyConfig {
    #[config(default = 1.0e-6)]
    pub epsilon: f64,
    pub dim: usize,
}

The derive also adds useful methods to your config.

fn main() {
    let config = MyConfig::new(100);
    println!("{}", config.epsilon); // 1.0e-6
    println!("{}", config.dim); // 100
    let config = MyConfig::new(100).with_epsilon(1.0e-8);
    println!("{}", config.epsilon); // 1.0e-8
}

Learner

The Learner is the main struct that lets you train a neural network with support for logging, metrics, checkpointing, and more. To create a learner, you must use the LearnerBuilder.

use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};

fn main() {
    let dataloader_train = ...;
    let dataloader_valid = ...;

    let model = ...;
    let optim = ...;

    let learner = LearnerBuilder::new("/tmp/artifact_dir")
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
        .with_file_checkpointer::<f32>(2)
        .num_epochs(10)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_valid);
}

See this example for real-world usage.

no_std support

Burn supports no_std with alloc for inference mode with the NdArray backend. Simply disable the default features of the burn and burn-ndarray packages (the minimum required to run inference). See the burn-no-std-tests package as a reference implementation. Additionally, the burn-core and burn-tensor packages support no_std with alloc if you need to include them directly as dependencies (the burn package re-exports burn-core and burn-tensor).

Note that under no_std, a random seed is generated at build time if the seed is not initialized by the Backend::seed method. Additionally, spin::mutex::Mutex is used in place of std::sync::Mutex under no_std.
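
As an illustration of the seeding note above, a minimal sketch (the backend type and seed value are assumptions chosen for demonstration):

use burn::tensor::backend::Backend;
use burn_ndarray::NdArrayBackend;

type B = NdArrayBackend<f32>;

fn init() {
    // Seed the backend explicitly; otherwise, under no_std, a seed generated
    // at build time is used.
    B::seed(42);
}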

License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.