*__Disclaimer__* _Burn is currently in active development, and there will be breaking changes. While any resulting issues are likely to be easy to fix, there are no guarantees at this stage._
__Sections__
* [Features](#features)
* [Get Started](#get-started)
* [Examples](#examples)
* [Components](#components)
* [Backend](#backend)
* [Tensor](#tensor)
* [Module](#module)
* [Config](#config)
* [Learner](#learner)
* [no_std support](#no_std-support)
* [License](#license)
## Features
* Flexible and intuitive custom neural network [module](#module) 🔥
* [Training](#learner) with full support for `metric`, `logging` and `checkpointing` 📈
* [Tensor](#tensor) crate with pluggable backends 🔧
* [Tch](https://github.com/burn-rs/burn/tree/main/burn-tch) backend with CPU/GPU support 🚀
* [NdArray](https://github.com/burn-rs/burn/tree/main/burn-ndarray) backend with fast compile time 👌
* [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend making any backend differentiable 🌟
* [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate with multiple utilities and sources 📚
## Get Started
The best way to get started with `burn` is to clone the repo and play with the [examples](#examples).
It may also be a good idea to take a look at the main [components](#components) of `burn` to get a quick overview of the fundamental building blocks.
### Examples
* [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist): train a model on CPU/GPU using different backends.
* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification): train a transformer encoder from scratch on GPU.
### Components
Knowing the main components will be of great help when you start playing with `burn`.
#### Backend
Almost everything is based on the `Backend` trait, which lets you run tensor operations with different implementations without having to change your code.
A backend does not necessarily have autodiff capabilities; the `ADBackend` trait is there to specify when autodiff is required.
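To make the idea of backend-generic code concrete, here is a toy sketch that uses only the standard library. `MiniBackend`, `VecBackend`, `BoxedBackend` and `add_vectors` are hypothetical names invented for illustration; they are not part of Burn, whose `Backend` trait is far richer. The function is written once and the implementation is chosen through the generic parameter.

```rust
/// Hypothetical, heavily simplified stand-in for a backend trait (not Burn's `Backend`).
pub trait MiniBackend {
    /// Backend-specific storage for tensor data.
    type Storage;

    fn from_vec(data: Vec<f32>) -> Self::Storage;
    fn add(lhs: &Self::Storage, rhs: &Self::Storage) -> Self::Storage;
    fn to_vec(storage: &Self::Storage) -> Vec<f32>;
}

/// A backend that keeps data in a plain `Vec<f32>`.
pub struct VecBackend;

impl MiniBackend for VecBackend {
    type Storage = Vec<f32>;

    fn from_vec(data: Vec<f32>) -> Vec<f32> {
        data
    }

    fn add(lhs: &Vec<f32>, rhs: &Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(a, b)| a + b).collect()
    }

    fn to_vec(storage: &Vec<f32>) -> Vec<f32> {
        storage.clone()
    }
}

/// A second backend using a boxed slice, standing in for a different implementation.
pub struct BoxedBackend;

impl MiniBackend for BoxedBackend {
    type Storage = Box<[f32]>;

    fn from_vec(data: Vec<f32>) -> Box<[f32]> {
        data.into_boxed_slice()
    }

    fn add(lhs: &Box<[f32]>, rhs: &Box<[f32]>) -> Box<[f32]> {
        lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
    }

    fn to_vec(storage: &Box<[f32]>) -> Vec<f32> {
        storage.to_vec()
    }
}

/// User code, written once and generic over the backend.
pub fn add_vectors<B: MiniBackend>(a: Vec<f32>, b: Vec<f32>) -> Vec<f32> {
    let x = B::from_vec(a);
    let y = B::from_vec(b);
    B::to_vec(&B::add(&x, &y))
}
```

Calling `add_vectors::<VecBackend>` or `add_vectors::<BoxedBackend>` gives the same result; only the type parameter changes.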
#### Tensor
The `Tensor` struct is at the core of the `burn` framework.
It takes two generic parameters: the `Backend` and the number of dimensions `D`.
Backpropagation is also supported on any backend by making them auto differentiable using a simple decorator.
```rust
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tch::TchBackend;

// Works with any backend `B`; the tensors have 2 dimensions.
fn simple_function<B: Backend>() -> Tensor<B, 2> {
    let x = Tensor::<B, 2>::random([3, 3], Distribution::Standard);
    let y = Tensor::<B, 2>::random([3, 3], Distribution::Standard);

    x.matmul(&y)
}

// Requires a backend with autodiff capabilities (`ADBackend`).
fn simple_function_grads<B: ADBackend>() -> B::Gradients {
    let z = simple_function::<B>();

    z.backward()
}

fn main() {
    let _z = simple_function::<NdArrayBackend<f32>>(); // Compiles
    let _z = simple_function::<TchBackend<f32>>(); // Compiles

    // let _grads = simple_function_grads::<NdArrayBackend<f32>>(); // Doesn't compile: no `ADBackend` impl
    // let _grads = simple_function_grads::<TchBackend<f32>>(); // Doesn't compile: no `ADBackend` impl

    type ADNdArrayBackend = ADBackendDecorator<NdArrayBackend<f32>>;
    type ADTchBackend = ADBackendDecorator<TchBackend<f32>>;

    let _grads = simple_function_grads::<ADNdArrayBackend>(); // Compiles
    let _grads = simple_function_grads::<ADTchBackend>(); // Compiles
}
```
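The decorator approach can be pictured with the following std-only toy. `Backend`, `AutodiffBackend`, `PlainBackend` and `Autodiff` here are hypothetical stand-ins, not Burn's traits, and the derivative is written out by hand for a single operation rather than recorded in a computation graph as a real autodiff backend would do. What it shows is the type-level pattern: the wrapper delegates forward operations to the inner backend and adds the autodiff trait on top, so only wrapped backends satisfy an `AutodiffBackend` bound.

```rust
use std::marker::PhantomData;

/// Hypothetical minimal backend trait (not Burn's `Backend`).
pub trait Backend {
    fn mul(a: f32, b: f32) -> f32;
}

/// Hypothetical sub-trait for backends that can also compute gradients.
pub trait AutodiffBackend: Backend {
    /// Returns the value of `x * x` and its derivative with respect to `x`.
    fn square_with_grad(x: f32) -> (f32, f32);
}

/// A plain backend with no autodiff capabilities.
pub struct PlainBackend;

impl Backend for PlainBackend {
    fn mul(a: f32, b: f32) -> f32 {
        a * b
    }
}

/// Decorator: wraps any backend `B` and adds autodiff on top of it.
pub struct Autodiff<B>(PhantomData<B>);

impl<B: Backend> Backend for Autodiff<B> {
    // Forward operations are delegated to the wrapped backend.
    fn mul(a: f32, b: f32) -> f32 {
        B::mul(a, b)
    }
}

impl<B: Backend> AutodiffBackend for Autodiff<B> {
    fn square_with_grad(x: f32) -> (f32, f32) {
        let value = B::mul(x, x);
        // Product rule for x * x with dx/dx = 1: 1 * x + x * 1 = 2x.
        let grad = B::mul(1.0, x) + B::mul(x, 1.0);
        (value, grad)
    }
}

/// Accepts any backend.
pub fn square<B: Backend>(x: f32) -> f32 {
    B::mul(x, x)
}

/// Only accepts backends with autodiff capabilities.
pub fn square_and_grad<B: AutodiffBackend>(x: f32) -> (f32, f32) {
    B::square_with_grad(x)
}
```

Here `square_and_grad::<Autodiff<PlainBackend>>(3.0)` returns `(9.0, 6.0)`, while `square_and_grad::<PlainBackend>(3.0)` is rejected at compile time because `PlainBackend` does not implement `AutodiffBackend`.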
#### Module
The `Module` derive lets you create your own neural network modules, similar to PyTorch.
```rust
use burn::nn;
use burn::module::{Param, Module};
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: Param<nn::Linear<B>>,
    repeat: usize,
}
```
Note that only the fields wrapped inside `Param` are updated during training; the other fields should implement `Clone`.
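The rule that only `Param` fields change during training can be illustrated with a hand-written, std-only sketch. The `Param` and `MyModule` below are hypothetical simplifications, not Burn's types, and `sgd_step` is an invented stand-in for whatever per-parameter update an optimizer applies. The update touches the wrapped field and leaves the plain `repeat` field alone.

```rust
/// Hypothetical stand-in for a trainable-parameter wrapper (not Burn's `Param`).
#[derive(Debug, Clone)]
pub struct Param<T> {
    pub value: T,
}

/// Hypothetical module: one trainable field, one plain field.
#[derive(Debug, Clone)]
pub struct MyModule {
    /// Trainable: wrapped in `Param`, updated by `sgd_step`.
    pub my_param: Param<Vec<f32>>,
    /// Not trainable: plain data that is cloned along with the module, never updated.
    pub repeat: usize,
}

impl MyModule {
    /// Applies one plain SGD update (`w -= lr * g`) to the `Param` field only.
    /// `repeat` is deliberately left untouched.
    pub fn sgd_step(&mut self, grads: &[f32], lr: f32) {
        for (w, g) in self.my_param.value.iter_mut().zip(grads) {
            *w -= lr * g;
        }
    }
}
```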
#### Config
The `Config` derive lets you define serializable and deserializable configurations or hyper-parameters for your [modules](#module) or any components.
```rust
use burn::config::Config;

#[derive(Config)]
struct MyConfig {
    #[config(default = 1.0e-6)]
    pub epsilon: f64,
    pub dim: usize,
}
```
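The derive also adds useful methods to your config, such as a `new` constructor that takes the fields without a default value, and `with_*` setters for the fields that have one.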
```rust
fn main() {
    let config = MyConfig::new(100);
    println!("{:e}", config.epsilon); // 1e-6
    println!("{}", config.dim); // 100

    let config = MyConfig::new(100).with_epsilon(1.0e-8);
    println!("{:e}", config.epsilon); // 1e-8
}
```
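For intuition, the following hand-written code approximates the `new` and `with_epsilon` methods used above. It is a hypothetical sketch, not the macro's actual expansion, and it omits the serialization support the derive also provides; only the standard library is used.

```rust
/// Hand-written approximation of what `#[derive(Config)]` offers for `MyConfig`.
#[derive(Debug, Clone, PartialEq)]
pub struct MyConfig {
    pub epsilon: f64,
    pub dim: usize,
}

impl MyConfig {
    /// Fields without a default value become constructor arguments;
    /// fields with a default (`epsilon`) are filled in automatically.
    pub fn new(dim: usize) -> Self {
        Self { epsilon: 1.0e-6, dim }
    }

    /// Fields with a default value get a chainable setter that consumes `self`.
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.epsilon = epsilon;
        self
    }
}
```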
#### Learner
The `Learner` is the main `struct` that lets you train a neural network with support for `logging`, `metric`, `checkpointing` and more.
In order to create a learner, you must use the `LearnerBuilder`.
```rust
use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};

fn main() {
    // Placeholders: build these for your own dataset, model and optimizer.
    let dataloader_train = ...;
    let dataloader_valid = ...;
    let model = ...;
    let optim = ...;

    let learner = LearnerBuilder::new("/tmp/artifact_dir")
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
        .with_file_checkpointer::<f32>(2)
        .num_epochs(10)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_valid);
}
```
See this [example](https://github.com/burn-rs/burn/tree/main/examples/mnist) for a real usage.
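`LearnerBuilder` follows the builder pattern: each method consumes the builder, records one option and returns it, and `build` produces the final object. The std-only sketch below shows that pattern with hypothetical `TrainerBuilder` and `Trainer` types. The `checkpoint_every` option is invented for the example and does not describe what the argument of `with_file_checkpointer` means; no files are written.

```rust
/// Hypothetical builder illustrating the pattern only (not Burn's API).
pub struct TrainerBuilder {
    artifact_dir: String,
    num_epochs: usize,
    checkpoint_every: Option<usize>,
}

/// The object produced by `TrainerBuilder::build`.
pub struct Trainer {
    artifact_dir: String,
    num_epochs: usize,
    checkpoint_every: Option<usize>,
}

impl TrainerBuilder {
    /// Starts with defaults: one epoch, no checkpoints.
    pub fn new(artifact_dir: &str) -> Self {
        Self {
            artifact_dir: artifact_dir.to_string(),
            num_epochs: 1,
            checkpoint_every: None,
        }
    }

    /// Each option consumes the builder and returns it, so calls can be chained.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    pub fn checkpoint_every(mut self, epochs: usize) -> Self {
        self.checkpoint_every = Some(epochs);
        self
    }

    pub fn build(self) -> Trainer {
        Trainer {
            artifact_dir: self.artifact_dir,
            num_epochs: self.num_epochs,
            checkpoint_every: self.checkpoint_every,
        }
    }
}

impl Trainer {
    pub fn artifact_dir(&self) -> &str {
        &self.artifact_dir
    }

    /// Calls `train_epoch` once per epoch (1-based) and returns the epochs
    /// at which a checkpoint would be saved. Nothing is written to disk.
    pub fn fit<F: FnMut(usize)>(&self, mut train_epoch: F) -> Vec<usize> {
        let mut checkpoints = Vec::new();
        for epoch in 1..=self.num_epochs {
            train_epoch(epoch);
            if let Some(every) = self.checkpoint_every {
                // Guard against a zero interval to avoid a division by zero.
                if every > 0 && epoch % every == 0 {
                    checkpoints.push(epoch);
                }
            }
        }
        checkpoints
    }
}
```

With `TrainerBuilder::new("/tmp/artifact_dir").num_epochs(5).checkpoint_every(2).build()`, `fit` runs epochs 1 to 5 and reports checkpoints at epochs 2 and 4.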
## no_std support
Burn supports `no_std` with `alloc` for inference mode with the NdArray backend. Simply disable the default features of the `burn` and `burn-ndarray` packages (the minimum required to run inference). See the [burn-no-std-tests](https://github.com/burn-rs/burn/tree/main/examples/burn-no-std-tests) package for a reference implementation. Additionally, the `burn-core` and `burn-tensor` packages support `no_std` with `alloc` if you need to include them directly as dependencies (the `burn` package re-exports `burn-core` and `burn-tensor`).
Note that in `no_std` mode, a random seed is generated at build time if the seed is not initialized with the `Backend::seed` method. Additionally, [spin::mutex::Mutex](https://docs.rs/spin/latest/spin/mutex/struct.Mutex.html) is used in place of [std::sync::Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html) in `no_std` mode.
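A spin lock such as `spin::mutex::Mutex` can work without the standard library because it busy-waits on an atomic flag instead of asking the operating system to block the thread. The sketch below illustrates that idea using only `core`. `SpinMutex` and `SpinGuard` are hypothetical names, and this is a simplified illustration rather than the `spin` crate's implementation: no fairness, no `try_lock`, no poisoning.

```rust
use core::cell::UnsafeCell;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicBool, Ordering};

/// Minimal spin lock: `locked` guards exclusive access to `data`.
pub struct SpinMutex<T> {
    locked: AtomicBool,
    data: UnsafeCell<T>,
}

// Safety: access to `data` is serialized by the `locked` flag.
unsafe impl<T: Send> Sync for SpinMutex<T> {}

/// Releases the lock when dropped.
pub struct SpinGuard<'a, T> {
    mutex: &'a SpinMutex<T>,
}

impl<T> SpinMutex<T> {
    pub const fn new(value: T) -> Self {
        Self {
            locked: AtomicBool::new(false),
            data: UnsafeCell::new(value),
        }
    }

    /// Busy-waits until the flag flips from `false` to `true`.
    pub fn lock(&self) -> SpinGuard<'_, T> {
        while self
            .locked
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            core::hint::spin_loop();
        }
        SpinGuard { mutex: self }
    }
}

impl<T> Deref for SpinGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // Safety: the guard exists only while the lock is held.
        unsafe { &*self.mutex.data.get() }
    }
}

impl<T> DerefMut for SpinGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // Safety: the guard exists only while the lock is held.
        unsafe { &mut *self.mutex.data.get() }
    }
}

impl<T> Drop for SpinGuard<'_, T> {
    fn drop(&mut self) {
        // Release pairs with the Acquire in `lock`.
        self.mutex.locked.store(false, Ordering::Release);
    }
}
```

Because `new` is a `const fn`, such a lock can back a global, for example a hypothetical `static SEED: SpinMutex<Option<u64>> = SpinMutex::new(None);`.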
## License
Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).
See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details.
Opening a pull request is assumed to signal agreement with these licensing terms.