mirror of https://github.com/tracel-ai/burn.git
Fix some issues in Burn book (#1042)
This commit is contained in:
parent
cb8d4dac62
commit
52811f9938
|
@ -15,7 +15,7 @@ impression that Burn operates at a high level over the backend layer. However, m
|
|||
explicit instead of being chosen via a compilation flag was a thoughtful design decision. This
|
||||
explicitness does not imply that all backends must be identical; rather, it offers a great deal of
|
||||
flexibility when composing backends. The autodifferentiation backend trait (see
|
||||
[autodiff section](../building-blocks/autodiff)) is an example of how the backend trait has been
|
||||
[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has been
|
||||
extended to enable gradient computation with backpropagation. Furthermore, this design allows you to
|
||||
create your own backend extension. To achieve this, you need to design your own backend trait
|
||||
specifying which functions should be supported.
|
||||
|
|
|
@ -34,11 +34,11 @@ Now let's create a simple `infer` method in which we will load our trained model
|
|||
|
||||
```rust , ignore
|
||||
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
|
||||
let config =
|
||||
TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists");
|
||||
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
|
||||
.expect("Config should exist for the model");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.expect("Failed to load trained model");
|
||||
.expect("Trained model should exist");
|
||||
|
||||
let model = config.model.init_with::<B>(record).to_device(&device);
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
|
|||
std::fs::create_dir_all(artifact_dir).ok();
|
||||
config
|
||||
.save(format!("{artifact_dir}/config.json"))
|
||||
.expect("Save without error");
|
||||
.expect("Config should be saved successfully");
|
||||
|
||||
B::seed(config.seed);
|
||||
|
||||
|
@ -128,7 +128,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
|
|||
|
||||
model_trained
|
||||
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
||||
.expect("Failed to save trained model");
|
||||
.expect("Trained model should be saved successfully");
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
@ -41,8 +41,8 @@ fn main() {
|
|||
|
||||
## Good practices
|
||||
|
||||
The interest of the Config pattern is to be able to easily create instances, factoried from this
|
||||
config. In that optic, initialization methods should be implemented on the config struct.
|
||||
By using the Config pattern it is easy to create instances from this
|
||||
config. Therefore, initialization methods should be implemented on the config struct.
|
||||
|
||||
```rust, ignore
|
||||
impl MyModuleConfig {
|
||||
|
@ -70,5 +70,5 @@ impl MyModuleConfig {
|
|||
Then we could add this line to the above `main`:
|
||||
|
||||
```rust, ignore
|
||||
let my_module = config.init()
|
||||
let my_module = config.init();
|
||||
```
|
||||
|
|
|
@ -32,7 +32,7 @@ The learner builder provides numerous options when it comes to configurations.
|
|||
| Devices | Set the devices to be used |
|
||||
| Checkpoint | Restart training from a checkpoint |
|
||||
|
||||
When the builder is configured at your liking, you can them move forward to build the learner. The
|
||||
When the builder is configured at your liking, you can then move forward to build the learner. The
|
||||
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note
|
||||
that the latter can be a simple float if you want it to be constant during training.
|
||||
|
||||
|
|
|
@ -208,7 +208,7 @@ This will result in the following compilation error:
|
|||
unconstrained type parameter [E0207]
|
||||
```
|
||||
|
||||
To resolve this issue, you have two options. The first one is to make your function is generic over
|
||||
To resolve this issue, you have two options. The first one is to make your function generic over
|
||||
the backend and add your trait constraint within its definition:
|
||||
|
||||
```rust, ignore
|
||||
|
|
|
@ -82,7 +82,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
|||
}
|
||||
}
|
||||
|
||||
fn build(&mut self, graph: &mut Graph<B>, mode: ExecutionMode) -> BuildAction<'_, B> {
|
||||
fn build(&mut self, graph: &Graph<B>, mode: ExecutionMode) -> BuildAction<'_, B> {
|
||||
// When we are executing with the new ops mode, we need to register the last ops of the
|
||||
// graph even when there is no skipped operation.
|
||||
let offset = match mode {
|
||||
|
@ -123,7 +123,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
|||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, graph: &mut Graph<B>) {
|
||||
fn reset(&mut self, graph: &Graph<B>) {
|
||||
for ops in self.optimizations.iter_mut() {
|
||||
ops.reset();
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
|||
|
||||
fn cache<'a>(
|
||||
&'a mut self,
|
||||
graph: &mut Graph<B>,
|
||||
graph: &Graph<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> CacheResult<'a, Box<dyn Optimization<B>>> {
|
||||
let (graph, next_ops) = Self::split_relative_graph_ref(graph, mode);
|
||||
|
|
|
@ -316,7 +316,7 @@ fn same_as_input(node: &mut Node) {
|
|||
}
|
||||
|
||||
/// Temporary pass-through stub for dimension inference so that we can export the IR model.
|
||||
fn temporary_pass_through_stub(node: &mut Node) {
|
||||
fn temporary_pass_through_stub(node: &Node) {
|
||||
log::warn!(
|
||||
"Must implement dimension inference for {:?}",
|
||||
node.node_type
|
||||
|
|
|
@ -9,11 +9,11 @@ use burn::{
|
|||
};
|
||||
|
||||
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
|
||||
let config =
|
||||
TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists");
|
||||
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
|
||||
.expect("Config should exist for the model");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.expect("Failed to load trained model");
|
||||
.expect("Trained model should exist");
|
||||
|
||||
let model = config.model.init_with::<B>(record).to_device(&device);
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
|
|||
std::fs::create_dir_all(artifact_dir).ok();
|
||||
config
|
||||
.save(format!("{artifact_dir}/config.json"))
|
||||
.expect("Save without error");
|
||||
.expect("Config should be saved successfully");
|
||||
|
||||
B::seed(config.seed);
|
||||
|
||||
|
@ -105,5 +105,5 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
|
|||
|
||||
model_trained
|
||||
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
||||
.expect("Failed to save trained model");
|
||||
.expect("Trained model should be saved successfully");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue