From 7baa33bdaaaa442a30184aceaa205dcfa8fa5923 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Thu, 29 Aug 2024 08:58:51 -0400
Subject: [PATCH] Fix target convert in batcher and align guide imports (#2215)

* Fix target convert in batcher

* Align hidden code in training and update loss to use config

* Align imports with example

* Remove unused import and fix guide->crate
---
 burn-book/src/basic-workflow/data.md      | 10 +++--
 burn-book/src/basic-workflow/inference.md | 13 +++---
 burn-book/src/basic-workflow/model.md     |  2 +-
 burn-book/src/basic-workflow/training.md  | 52 ++++++++++-------------
 4 files changed, 35 insertions(+), 42 deletions(-)

diff --git a/burn-book/src/basic-workflow/data.md b/burn-book/src/basic-workflow/data.md
index 4e3683c21..dab324e95 100644
--- a/burn-book/src/basic-workflow/data.md
+++ b/burn-book/src/basic-workflow/data.md
@@ -79,10 +79,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
 
         let targets = items
             .iter()
-            .map(|item| Tensor::<B, 1, Int>::from_data(
-                TensorData::from([(item.label as i64).elem()]),
-                &self.device
-            ))
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    [(item.label as i64).elem::<B::IntElem>()],
+                    &self.device,
+                )
+            })
             .collect();
 
         let images = Tensor::cat(images, 0).to_device(&self.device);
diff --git a/burn-book/src/basic-workflow/inference.md b/burn-book/src/basic-workflow/inference.md
index 1195055ae..88ae9afc7 100644
--- a/burn-book/src/basic-workflow/inference.md
+++ b/burn-book/src/basic-workflow/inference.md
@@ -10,15 +10,12 @@ cost.
 Let's create a simple `infer` method in a new file `src/inference.rs` which will be used to
 load our trained model.
 
 ```rust , ignore
-# use burn::{
-#     config::Config,
-#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
-#     module::Module,
-#     record::{CompactRecorder, Recorder},
-#     tensor::backend::Backend,
-# };
-#
 # use crate::{data::MnistBatcher, training::TrainingConfig};
+# use burn::{
+#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+#     prelude::*,
+#     record::{CompactRecorder, Recorder},
+# };
 #
 pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
     let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md
index bfead129e..e8471fa12 100644
--- a/burn-book/src/basic-workflow/model.md
+++ b/burn-book/src/basic-workflow/model.md
@@ -221,8 +221,8 @@ impl ModelConfig {
 At a glance, you can view the model configuration by printing the model instance:
 
 ```rust , ignore
+use crate::model::ModelConfig;
 use burn::backend::Wgpu;
-use guide::model::ModelConfig;
 
 fn main() {
     type MyBackend = Wgpu<f32, i32>;
diff --git a/burn-book/src/basic-workflow/training.md b/burn-book/src/basic-workflow/training.md
index cc40a6e3e..6705beed1 100644
--- a/burn-book/src/basic-workflow/training.md
+++ b/burn-book/src/basic-workflow/training.md
@@ -39,7 +39,9 @@ impl<B: Backend> Model<B> {
         targets: Tensor<B, 1, Int>,
     ) -> ClassificationOutput<B> {
         let output = self.forward(images);
-        let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+        let loss = CrossEntropyLossConfig::new()
+            .init(&output.device())
+            .forward(output.clone(), targets.clone());
 
         ClassificationOutput::new(loss, output, targets)
     }
@@ -60,28 +62,23 @@ Moving forward, we will proceed with the implementation of both the training and validation steps
 for our model.
 
 ```rust , ignore
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
+#     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
 #
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
-#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
@@ -89,8 +86,9 @@ for our model.
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }
@@ -147,28 +145,23 @@ Book.
 Let us move on to establishing the practical training configuration.
 
 ```rust , ignore
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
+#     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
 #
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
-#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
@@ -176,8 +169,9 @@ Let us move on to establishing the practical training configuration.
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }
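
For reference, a minimal standalone sketch of the target conversion the data.md hunk lands on. This is not part of the patch: `label_to_target` is a hypothetical helper name, and it assumes `burn::prelude::*` brings `Backend`, `Tensor`, `Int`, and `ElementConversion` into scope as the guide's imports suggest.

```rust
use burn::prelude::*;

/// Convert a class label into a 1-D integer target tensor on the given device.
fn label_to_target<B: Backend>(label: u8, device: &B::Device) -> Tensor<B, 1, Int> {
    // The explicit `.elem::<B::IntElem>()` turbofish converts the label to the
    // backend's integer element type, so `from_data` can take the array directly
    // instead of going through a `TensorData::from(...)` wrapper.
    Tensor::<B, 1, Int>::from_data([(label as i64).elem::<B::IntElem>()], device)
}
```

In the patched batcher, this same expression is the closure body mapped over `items`, with `item.label` in place of `label`.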
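The training.md hunks make the analogous swap for the loss: the positional `CrossEntropyLoss::new(None, &device)` constructor is replaced by the config builder. A minimal sketch of the new construction, assuming only that `CrossEntropyLossConfig::new()`'s defaults reproduce the previous `None` pad-token behavior:

```rust
use burn::nn::loss::CrossEntropyLossConfig;
use burn::prelude::*;

/// Compute cross-entropy loss from raw logits using the config-based builder.
fn cross_entropy<B: Backend>(
    logits: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
    CrossEntropyLossConfig::new()
        .init(&logits.device()) // build a CrossEntropyLoss<B> on the logits' device
        .forward(logits, targets)
}
```

The builder form leaves room to set options such as class weights or label smoothing on the config before `init`, which is why the guide prefers it over the positional constructor.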