mirror of https://github.com/tracel-ai/burn.git
Print module part3 - Update book (#1940)
* Update book example guide * Update Module book section on module display
This commit is contained in:
parent
3a9367de73
commit
6f2ba34382
|
@ -192,7 +192,7 @@ Next, we need to instantiate the model for training.
|
|||
# linear2: Linear<B>,
|
||||
# activation: Relu,
|
||||
# }
|
||||
#
|
||||
#
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ModelConfig {
|
||||
num_classes: usize,
|
||||
|
@ -217,6 +217,40 @@ impl ModelConfig {
|
|||
}
|
||||
```
|
||||
|
||||
|
||||
At a glance, you can view the model configuration by printing the model instance:
|
||||
|
||||
```rust , ignore
|
||||
use burn::backend::Wgpu;
|
||||
use guide::model::ModelConfig;
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
|
||||
let device = Default::default();
|
||||
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
|
||||
|
||||
println!("{}", model);
|
||||
}
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```rust , ignore
|
||||
Model {
|
||||
conv1: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80}
|
||||
conv2: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168}
|
||||
pool: AdaptiveAvgPool2d {output_size: [8, 8]}
|
||||
dropout: Dropout {prob: 0.5}
|
||||
linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800}
|
||||
linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130}
|
||||
activation: Relu
|
||||
params: 531178
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🦀 References</strong></summary>
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.
|
|||
These methods are available for all modules.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|-----------------------------------------|------------------------------------------|
|
||||
| --------------------------------------- | ---------------------------------------- |
|
||||
| `module.devices()` | N/A |
|
||||
| `module.fork(device)` | Similar to `module.to(device).detach()` |
|
||||
| `module.to_device(device)` | `module.to(device)` |
|
||||
|
@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif
|
|||
autodiff support.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|------------------|--------------------|
|
||||
| ---------------- | ------------------ |
|
||||
| `module.valid()` | `module.eval()` |
|
||||
|
||||
## Visitor & Mapper
|
||||
|
@ -96,7 +96,62 @@ pub trait ModuleVisitor<B: Backend> {
|
|||
/// Module mapper trait.
|
||||
pub trait ModuleMapper<B: Backend> {
|
||||
/// Map a tensor in the module.
|
||||
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
|
||||
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) ->
|
||||
Tensor<B, D>;
|
||||
}
|
||||
```
|
||||
|
||||
## Module Display
|
||||
|
||||
Burn provides a simple way to display the structure of a module and its configuration at a glance.
|
||||
You can print the module to see its structure, which is useful for debugging and tracking changes
|
||||
across different versions of a module. (See the print output of the
|
||||
[Basic Workflow Model](../basic-workflow/model.md) example.)
|
||||
|
||||
To customize the display of a module, you can implement the `ModuleDisplay` trait for your module.
|
||||
This will change the default display settings for the module and its children. Note that
|
||||
`ModuleDisplay` is automatically implemented for all modules, but you can override it to customize
|
||||
the display by annotating the module with `#[module(custom_display)]`.
|
||||
|
||||
```rust
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PositionWiseFeedForward<B: Backend> {
|
||||
linear_inner: Linear<B>,
|
||||
linear_outer: Linear<B>,
|
||||
dropout: Dropout,
|
||||
gelu: Gelu,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
|
||||
/// Custom settings for the display of the module.
|
||||
/// If `None` is returned, the default settings will be used.
|
||||
fn custom_settings(&self) -> Option<burn::module::DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
// Whether to show all attributes (default is false)
|
||||
.with_show_all_attributes(false)
|
||||
// Will show each attribute on a new line (default is true)
|
||||
.with_new_line_after_attribute(true)
|
||||
// Will show the number of parameters (default is true)
|
||||
.with_show_num_parameters(true)
|
||||
// Will indent by 2 spaces (default is 2)
|
||||
.with_indentation_size(2)
|
||||
// Whether to show the parameter ID (default is false)
|
||||
.with_show_param_id(false)
|
||||
// Convenience method to wrap settings in Some()
|
||||
.optional()
|
||||
}
|
||||
|
||||
/// Custom content to be displayed.
|
||||
/// If `None` is returned, the default content will be used
|
||||
/// (all attributes of the module)
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("linear_inner", &self.linear_inner)
|
||||
.add("linear_outer", &self.linear_outer)
|
||||
.add("anything", "anything_else")
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -107,7 +162,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### General
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|----------------|-----------------------------------------------|
|
||||
| -------------- | --------------------------------------------- |
|
||||
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
|
||||
| `Dropout` | `nn.Dropout` |
|
||||
| `Embedding` | `nn.Embedding` |
|
||||
|
@ -125,7 +180,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Convolutions
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|-------------------|----------------------|
|
||||
| ----------------- | -------------------- |
|
||||
| `Conv1d` | `nn.Conv1d` |
|
||||
| `Conv2d` | `nn.Conv2d` |
|
||||
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
|
||||
|
@ -134,7 +189,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Pooling
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|---------------------|------------------------|
|
||||
| ------------------- | ---------------------- |
|
||||
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
|
||||
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
|
||||
| `AvgPool1d` | `nn.AvgPool1d` |
|
||||
|
@ -145,7 +200,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### RNNs
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|------------------|------------------------|
|
||||
| ---------------- | ---------------------- |
|
||||
| `Gru` | `nn.GRU` |
|
||||
| `Lstm`/`BiLstm` | `nn.LSTM` |
|
||||
| `GateController` | _No direct equivalent_ |
|
||||
|
@ -153,7 +208,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Transformer
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|----------------------|-------------------------|
|
||||
| -------------------- | ----------------------- |
|
||||
| `MultiHeadAttention` | `nn.MultiheadAttention` |
|
||||
| `TransformerDecoder` | `nn.TransformerDecoder` |
|
||||
| `TransformerEncoder` | `nn.TransformerEncoder` |
|
||||
|
@ -163,7 +218,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Loss
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------|-----------------------|
|
||||
| ------------------ | --------------------- |
|
||||
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
|
||||
| `MseLoss` | `nn.MSELoss` |
|
||||
| `HuberLoss` | `nn.HuberLoss` |
|
||||
|
|
|
@ -4,6 +4,21 @@ This example corresponds to the [book's guide](https://burn.dev/book/basic-workf
|
|||
|
||||
## Example Usage
|
||||
|
||||
|
||||
### Training
|
||||
|
||||
```sh
|
||||
cargo run --example guide
|
||||
```
|
||||
cargo run --bin train --release
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
```sh
|
||||
cargo run --bin infer --release
|
||||
```
|
||||
|
||||
### Print the model
|
||||
|
||||
```sh
|
||||
cargo run --bin print --release
|
||||
```
|
||||
|
|
|
@ -10,7 +10,7 @@ use std::process::Command;
|
|||
|
||||
fn main() {
|
||||
Command::new("cargo")
|
||||
.args(["run", "--bin", "guide"])
|
||||
.args(["run", "--bin", "train", "--release"])
|
||||
.status()
|
||||
.expect("guide example should run");
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
use burn::{backend::Wgpu, data::dataset::Dataset};
|
||||
use guide::inference;
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts are saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
|
||||
// Run inference on a single test item using the trained model
|
||||
inference::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
||||
.get(42)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
use burn::backend::Wgpu;
|
||||
use guide::model::ModelConfig;
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
|
||||
let device = Default::default();
|
||||
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
|
||||
|
||||
println!("{}", model);
|
||||
}
|
|
@ -1,27 +1,33 @@
|
|||
mod data;
|
||||
mod inference;
|
||||
mod model;
|
||||
mod training;
|
||||
|
||||
use crate::{model::ModelConfig, training::TrainingConfig};
|
||||
use burn::{
|
||||
backend::{Autodiff, Wgpu},
|
||||
data::dataset::Dataset,
|
||||
optim::AdamConfig,
|
||||
};
|
||||
use guide::{
|
||||
inference,
|
||||
model::ModelConfig,
|
||||
training::{self, TrainingConfig},
|
||||
};
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
|
||||
// Create a default Wgpu device
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
// All the training artifacts will be saved in this directory
|
||||
let artifact_dir = "/tmp/guide";
|
||||
crate::training::train::<MyAutodiffBackend>(
|
||||
|
||||
// Train the model
|
||||
training::train::<MyAutodiffBackend>(
|
||||
artifact_dir,
|
||||
TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
|
||||
device.clone(),
|
||||
);
|
||||
crate::inference::infer::<MyBackend>(
|
||||
|
||||
// Infer the model
|
||||
inference::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
device,
|
||||
burn::data::dataset::vision::MnistDataset::test()
|
|
@ -7,10 +7,10 @@ use burn::{
|
|||
|
||||
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
|
||||
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
|
||||
.expect("Config should exist for the model");
|
||||
.expect("Config should exist for the model; run train first");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into(), &device)
|
||||
.expect("Trained model should exist");
|
||||
.expect("Trained model should exist; run train first");
|
||||
|
||||
let model: Model<B> = config.model.init(&device).load_record(record);
|
||||
|
||||
|
|
Loading…
Reference in New Issue