Print module part3 - Update book (#1940)

* Update book example guide

* Update Module book section on module display
This commit is contained in:
Dilshod Tadjibaev 2024-07-01 12:42:17 -05:00 committed by GitHub
parent 3a9367de73
commit 6f2ba34382
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 164 additions and 23 deletions

View File

@ -192,7 +192,7 @@ Next, we need to instantiate the model for training.
# linear2: Linear<B>,
# activation: Relu,
# }
#
#
#[derive(Config, Debug)]
pub struct ModelConfig {
num_classes: usize,
@ -217,6 +217,40 @@ impl ModelConfig {
}
```
At a glance, you can view the model configuration by printing the model instance:
```rust , ignore
use burn::backend::Wgpu;
use guide::model::ModelConfig;
fn main() {
type MyBackend = Wgpu<f32, i32>;
let device = Default::default();
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
println!("{}", model);
}
```
Output:
```rust , ignore
Model {
conv1: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80}
conv2: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168}
pool: AdaptiveAvgPool2d {output_size: [8, 8]}
dropout: Dropout {prob: 0.5}
linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800}
linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130}
activation: Relu
params: 531178
}
```
<details>
<summary><strong>🦀 References</strong></summary>

View File

@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.
These methods are available for all modules.
| Burn API | PyTorch Equivalent |
|-----------------------------------------|------------------------------------------|
| --------------------------------------- | ---------------------------------------- |
| `module.devices()` | N/A |
| `module.fork(device)` | Similar to `module.to(device).detach()` |
| `module.to_device(device)` | `module.to(device)` |
@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif
autodiff support.
| Burn API | PyTorch Equivalent |
|------------------|--------------------|
| ---------------- | ------------------ |
| `module.valid()` | `module.eval()` |
## Visitor & Mapper
@ -96,7 +96,62 @@ pub trait ModuleVisitor<B: Backend> {
/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) ->
Tensor<B, D>;
}
```
## Module Display
Burn provides a simple way to display the structure of a module and its configuration at a glance.
You can print the module to see its structure, which is useful for debugging and tracking changes
across different versions of a module. (See the print output of the
[Basic Workflow Model](../basic-workflow/model.md) example.)
To customize the display of a module, you can implement the `ModuleDisplay` trait for your module.
This will change the default display settings for the module and its children. Note that
`ModuleDisplay` is automatically implemented for all modules, but you can override it to customize
the display by annotating the module with `#[module(custom_display)]`.
```rust
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: Gelu,
}
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
/// Custom settings for the display of the module.
/// If `None` is returned, the default settings will be used.
fn custom_settings(&self) -> Option<burn::module::DisplaySettings> {
DisplaySettings::new()
// Will show all attributes (default is false)
.with_show_all_attributes(false)
// Will show each attribute on a new line (default is true)
.with_new_line_after_attribute(true)
// Will show the number of parameters (default is true)
.with_show_num_parameters(true)
// Will indent by 2 spaces (default is 2)
.with_indentation_size(2)
// Will show the parameter ID (default is false)
.with_show_param_id(false)
// Convenience method to wrap settings in Some()
.optional()
}
/// Custom content to be displayed.
/// If `None` is returned, the default content will be used
/// (all attributes of the module)
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("linear_inner", &self.linear_inner)
.add("linear_outer", &self.linear_outer)
.add("anything", "anything_else")
.optional()
}
}
```
@ -107,7 +162,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### General
| Burn API | PyTorch Equivalent |
|----------------|-----------------------------------------------|
| -------------- | --------------------------------------------- |
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
| `Dropout` | `nn.Dropout` |
| `Embedding` | `nn.Embedding` |
@ -125,7 +180,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Convolutions
| Burn API | PyTorch Equivalent |
|-------------------|----------------------|
| ----------------- | -------------------- |
| `Conv1d` | `nn.Conv1d` |
| `Conv2d` | `nn.Conv2d` |
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
@ -134,7 +189,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Pooling
| Burn API | PyTorch Equivalent |
|---------------------|------------------------|
| ------------------- | ---------------------- |
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
| `AvgPool1d` | `nn.AvgPool1d` |
@ -145,7 +200,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### RNNs
| Burn API | PyTorch Equivalent |
|------------------|------------------------|
| ---------------- | ---------------------- |
| `Gru` | `nn.GRU` |
| `Lstm`/`BiLstm` | `nn.LSTM` |
| `GateController` | _No direct equivalent_ |
@ -153,7 +208,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Transformer
| Burn API | PyTorch Equivalent |
|----------------------|-------------------------|
| -------------------- | ----------------------- |
| `MultiHeadAttention` | `nn.MultiheadAttention` |
| `TransformerDecoder` | `nn.TransformerDecoder` |
| `TransformerEncoder` | `nn.TransformerEncoder` |
@ -163,7 +218,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Loss
| Burn API | PyTorch Equivalent |
|--------------------|-----------------------|
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |

View File

@ -4,6 +4,21 @@ This example corresponds to the [book's guide](https://burn.dev/book/basic-workf
## Example Usage
### Training
```sh
cargo run --example guide
```
cargo run --bin train --release
```
### Inference
```sh
cargo run --bin infer --release
```
### Print the model
```sh
cargo run --bin print --release
```

View File

@ -10,7 +10,7 @@ use std::process::Command;
fn main() {
Command::new("cargo")
.args(["run", "--bin", "guide"])
.args(["run", "--bin", "train", "--release"])
.status()
.expect("guide example should run");
}

View File

@ -0,0 +1,20 @@
use burn::{backend::Wgpu, data::dataset::Dataset};
use guide::inference;
fn main() {
type MyBackend = Wgpu<f32, i32>;
let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts are saved in this directory
let artifact_dir = "/tmp/guide";
// Infer the model
inference::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()
.get(42)
.unwrap(),
);
}

View File

@ -0,0 +1,11 @@
use burn::backend::Wgpu;
use guide::model::ModelConfig;
fn main() {
type MyBackend = Wgpu<f32, i32>;
let device = Default::default();
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
println!("{}", model);
}

View File

@ -1,27 +1,33 @@
mod data;
mod inference;
mod model;
mod training;
use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
use guide::{
inference,
model::ModelConfig,
training::{self, TrainingConfig},
};
fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
// Create a default Wgpu device
let device = burn::backend::wgpu::WgpuDevice::default();
// All the training artifacts will be saved in this directory
let artifact_dir = "/tmp/guide";
crate::training::train::<MyAutodiffBackend>(
// Train the model
training::train::<MyAutodiffBackend>(
artifact_dir,
TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
device.clone(),
);
crate::inference::infer::<MyBackend>(
// Infer the model
inference::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()

View File

@ -7,10 +7,10 @@ use burn::{
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model");
.expect("Config should exist for the model; run train first");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist");
.expect("Trained model should exist; run train first");
let model: Model<B> = config.model.init(&device).load_record(record);