mirror of https://github.com/tracel-ai/burn.git

Add hidden code snippets to guide example in Burn book [redo] (#1742)

* added hidden code snippets in Burn book guide example
* Update backend.md
* Revert last commit

This commit is contained in:
parent adbe97dc4d
commit e233c38b0f
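For context on the mechanism this commit uses (background, not part of the commit message): in mdBook, lines of a Rust code block that begin with `#` are hidden on the rendered page but remain part of the snippet, so each example stays complete and compilable while the book shows only the lines under discussion. A minimal illustration:

```rust , ignore
# // Hidden from the rendered page, but still part of the snippet,
# // so the visible lines below compile in context.
# fn setup() {}
fn main() {
    // Only this part is shown to the reader.
    setup();
}
```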
@@ -5,19 +5,27 @@ explicitly designated the backend to be used at any point. This will be defined
 entrypoint of our program, namely the `main` function defined in `src/main.rs`.

 ```rust , ignore
-use burn::optim::AdamConfig;
-use burn::backend::{Autodiff, Wgpu, wgpu::AutoGraphicsApi};
-use crate::model::ModelConfig;
+# mod data;
+# mod model;
+# mod training;
+#
+use crate::{model::ModelConfig, training::TrainingConfig};
+use burn::{
+    backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
+#     data::dataset::Dataset,
+    optim::AdamConfig,
+};

 fn main() {
     type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
     type MyAutodiffBackend = Autodiff<MyBackend>;

     let device = burn::backend::wgpu::WgpuDevice::default();
+    let artifact_dir = "/tmp/guide";
     crate::training::train::<MyAutodiffBackend>(
-        "/tmp/guide",
-        crate::training::TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
-        device,
+        artifact_dir,
+        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
+        device.clone(),
     );
 }
 ```
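After this change, the rendered page shows only the unprefixed lines of the block above, which now read as a complete program:

```rust , ignore
use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
    backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
    optim::AdamConfig,
};

fn main() {
    type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let artifact_dir = "/tmp/guide";
    crate::training::train::<MyAutodiffBackend>(
        artifact_dir,
        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
        device.clone(),
    );
}
```

The switch from `device` to `device.clone()` anticipates a later hunk, where the same device is reused for the `infer` call.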
@@ -42,6 +42,22 @@ not all backends expose the same devices. As an example, the Libtorch-based back
 Next, we need to actually implement the batching logic.

 ```rust , ignore
+# use burn::{
+#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+#     prelude::*,
+# };
+#
+# #[derive(Clone)]
+# pub struct MnistBatcher<B: Backend> {
+#     device: B::Device,
+# }
+#
+# impl<B: Backend> MnistBatcher<B> {
+#     pub fn new(device: B::Device) -> Self {
+#         Self { device }
+#     }
+# }
+#
 #[derive(Clone, Debug)]
 pub struct MnistBatch<B: Backend> {
     pub images: Tensor<B, 3>,
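The hidden lines above give the snippet the `MnistBatcher` definition it needs to compile on its own. A usage sketch, as an assumption for illustration only (`my_items` is a hypothetical `Vec<MnistItem>`, and the `Batcher` implementation appears later in the guide):

```rust , ignore
// Sketch: the hidden struct is built from a device; `batch` (from the
// `Batcher` trait) turns raw items into one tensor-backed batch.
let batcher = MnistBatcher::<MyBackend>::new(device.clone());
let batch: MnistBatch<MyBackend> = batcher.batch(my_items);
```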
@@ -10,6 +10,16 @@ cost. Let's create a simple `infer` method in a new file `src/inference.rs` which
 load our trained model.

 ```rust , ignore
+# use burn::{
+#     config::Config,
+#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+#     module::Module,
+#     record::{CompactRecorder, Recorder},
+#     tensor::backend::Backend,
+# };
+#
+# use crate::{data::MnistBatcher, training::TrainingConfig};
+#
 pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
     let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
         .expect("Config should exist for the model");
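The hidden imports (`CompactRecorder`, `Recorder`) support the model-loading step that follows in the guide. A sketch of that step, assuming the recorder API as used elsewhere in this guide:

```rust , ignore
// Sketch: load the weights saved by `train`, then rebuild the model from them.
let record = CompactRecorder::new()
    .load(format!("{artifact_dir}/model").into(), &device)
    .expect("Trained model should exist");
let model = config.model.init::<B>(&device).load_record(record);
```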
@@ -39,6 +49,29 @@ By running the infer function, you should see the predictions of your model!
 Add the call to `infer` to the `main.rs` file after the `train` function call:

 ```rust , ignore
+# mod data;
+# mod inference;
+# mod model;
+# mod training;
+#
+# use crate::{model::ModelConfig, training::TrainingConfig};
+# use burn::{
+#     backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
+#     data::dataset::Dataset,
+#     optim::AdamConfig,
+# };
+#
+# fn main() {
+#     type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
+#     type MyAutodiffBackend = Autodiff<MyBackend>;
+#
+#     let device = burn::backend::wgpu::WgpuDevice::default();
+#     let artifact_dir = "/tmp/guide";
+#     crate::training::train::<MyAutodiffBackend>(
+#         artifact_dir,
+#         TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
+#         device.clone(),
+#     );
     crate::inference::infer::<MyBackend>(
         artifact_dir,
         device,

@@ -46,6 +79,7 @@ Add the call to `infer` to the `main.rs` file after the `train` function call:
             .get(42)
             .unwrap(),
     );
+# }
 ```

 The number `42` is the index of the image in the MNIST dataset. You can explore and verify them using
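The trailing sentence above introduces dataset exploration. A small sketch of what that can look like with the `Dataset` trait the hidden lines import (printing `label` assumes `MnistItem`'s public field):

```rust , ignore
use burn::data::dataset::{vision::MnistDataset, Dataset};

// Sketch: fetch the same item the guide passes to `infer` and inspect it.
let item = MnistDataset::test().get(42).unwrap();
println!("label at index 42: {}", item.label);
```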
@@ -165,11 +165,34 @@ at the top of the main file:

 ```rust , ignore
 mod model;
+#
+# fn main() {
+# }
 ```

 Next, we need to instantiate the model for training.

 ```rust , ignore
+# use burn::{
+#     nn::{
+#         conv::{Conv2d, Conv2dConfig},
+#         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
+#         Dropout, DropoutConfig, Linear, LinearConfig, Relu,
+#     },
+#     prelude::*,
+# };
+#
+# #[derive(Module, Debug)]
+# pub struct Model<B: Backend> {
+#     conv1: Conv2d<B>,
+#     conv2: Conv2d<B>,
+#     pool: AdaptiveAvgPool2d,
+#     dropout: Dropout,
+#     linear1: Linear<B>,
+#     linear2: Linear<B>,
+#     activation: Relu,
+# }
+#
 #[derive(Config, Debug)]
 pub struct ModelConfig {
     num_classes: usize,
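With the hidden `Model` definition in place, the visible `ModelConfig` block compiles on its own. A brief usage sketch, assuming the `init` method shown in the next hunk:

```rust , ignore
// Sketch: build a model with 10 output classes and a hidden size of 512,
// mirroring the call in `main.rs` above.
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
```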
@@ -253,6 +276,49 @@ which we will flatten in the forward pass to have a 1024 (16 _ 8 _ 8) resulting
 Now let's see how the forward pass is defined.

 ```rust , ignore
+# use burn::{
+#     nn::{
+#         conv::{Conv2d, Conv2dConfig},
+#         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
+#         Dropout, DropoutConfig, Linear, LinearConfig, Relu,
+#     },
+#     prelude::*,
+# };
+#
+# #[derive(Module, Debug)]
+# pub struct Model<B: Backend> {
+#     conv1: Conv2d<B>,
+#     conv2: Conv2d<B>,
+#     pool: AdaptiveAvgPool2d,
+#     dropout: Dropout,
+#     linear1: Linear<B>,
+#     linear2: Linear<B>,
+#     activation: Relu,
+# }
+#
+# #[derive(Config, Debug)]
+# pub struct ModelConfig {
+#     num_classes: usize,
+#     hidden_size: usize,
+#     #[config(default = "0.5")]
+#     dropout: f64,
+# }
+#
+# impl ModelConfig {
+#     /// Returns the initialized model.
+#     pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
+#         Model {
+#             conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
+#             conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
+#             pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
+#             activation: Relu::new(),
+#             linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
+#             linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
+#             dropout: DropoutConfig::new(self.dropout).init(),
+#         }
+#     }
+# }
+#
 impl<B: Backend> Model<B> {
     /// # Shapes
     ///   - Images [batch_size, height, width]
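The `16 * 8 * 8` in the hidden `init` is worth unpacking: `conv2` produces 16 channels, and `AdaptiveAvgPool2dConfig::new([8, 8])` fixes the spatial size at 8 x 8, so flattening always yields 1024 features. As a quick check:

```rust , ignore
// Sketch: the arithmetic behind `LinearConfig::new(16 * 8 * 8, hidden_size)`.
let (channels, pooled_h, pooled_w) = (16, 8, 8);
assert_eq!(channels * pooled_h * pooled_w, 1024);
```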
@@ -15,6 +15,23 @@ beyond the scope of this guide.
 Since the MNIST task is a classification problem, we will use the `ClassificationOutput` type.

 ```rust , ignore
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+# use burn::{
+#     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
+#     nn::loss::CrossEntropyLossConfig,
+#     optim::AdamConfig,
+#     prelude::*,
+#     record::CompactRecorder,
+#     tensor::backend::AutodiffBackend,
+#     train::{
+#         metric::{AccuracyMetric, LossMetric},
+#         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
+#     },
+# };
+#
 impl<B: Backend> Model<B> {
     pub fn forward_classification(
         &self,
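For orientation, a hypothetical call sketch (not part of the diff; `model` and `batch` are assumed to exist) showing what `forward_classification` produces:

```rust , ignore
// Sketch: one classification pass; the result bundles loss, logits, and targets.
let output = model.forward_classification(batch.images, batch.targets);
println!("batch loss: {}", output.loss.into_scalar());
```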
@@ -43,6 +60,42 @@ Moving forward, we will proceed with the implementation of both the training and
 for our model.

 ```rust , ignore
+# use burn::{
+#     config::Config,
+#     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
+#     module::Module,
+#     nn::loss::CrossEntropyLoss,
+#     optim::AdamConfig,
+#     record::CompactRecorder,
+#     tensor::{
+#         backend::{AutodiffBackend, Backend},
+#         Int, Tensor,
+#     },
+#     train::{
+#         metric::{AccuracyMetric, LossMetric},
+#         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
+#     },
+# };
+#
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+#
+# impl<B: Backend> Model<B> {
+#     pub fn forward_classification(
+#         &self,
+#         images: Tensor<B, 3>,
+#         targets: Tensor<B, 1, Int>,
+#     ) -> ClassificationOutput<B> {
+#         let output = self.forward(images);
+#         let loss =
+#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#
+#         ClassificationOutput::new(loss, output, targets)
+#     }
+# }
+#
 impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
     fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
         let item = self.forward_classification(batch.images, batch.targets);
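Note the asymmetry the hidden lines in the next hunk spell out: `TrainStep::step` wraps its result in `TrainOutput::new(self, item.loss.backward(), item)` so gradients are computed, while `ValidStep::step` returns the `ClassificationOutput` directly, with no backward pass.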
@@ -94,6 +147,56 @@ Book.
 Let us move on to establishing the practical training configuration.

 ```rust , ignore
+# use burn::{
+#     config::Config,
+#     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
+#     module::Module,
+#     nn::loss::CrossEntropyLoss,
+#     optim::AdamConfig,
+#     record::CompactRecorder,
+#     tensor::{
+#         backend::{AutodiffBackend, Backend},
+#         Int, Tensor,
+#     },
+#     train::{
+#         metric::{AccuracyMetric, LossMetric},
+#         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
+#     },
+# };
+#
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+#
+# impl<B: Backend> Model<B> {
+#     pub fn forward_classification(
+#         &self,
+#         images: Tensor<B, 3>,
+#         targets: Tensor<B, 1, Int>,
+#     ) -> ClassificationOutput<B> {
+#         let output = self.forward(images);
+#         let loss =
+#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#
+#         ClassificationOutput::new(loss, output, targets)
+#     }
+# }
+#
+# impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+#     fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+#         let item = self.forward_classification(batch.images, batch.targets);
+#
+#         TrainOutput::new(self, item.loss.backward(), item)
+#     }
+# }
+#
+# impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+#     fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
+#         self.forward_classification(batch.images, batch.targets)
+#     }
+# }
+#
 #[derive(Config)]
 pub struct TrainingConfig {
     pub model: ModelConfig,