Fix target convert in batcher and align guide imports (#2215)

* Fix target convert in batcher

* Align hidden code in training and update loss to use config

* Align imports with example

* Remove unused import and fix guide->crate
Guillaume Lagrange 2024-08-29 08:58:51 -04:00
parent 3664c6ac69
commit 0e445a9680
4 changed files with 35 additions and 42 deletions


@@ -79,10 +79,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
         let targets = items
             .iter()
-            .map(|item| Tensor::<B, 1, Int>::from_data(
-                TensorData::from([(item.label as i64).elem()]),
-                &self.device
-            ))
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    [(item.label as i64).elem::<B::IntElem>()],
+                    &self.device,
+                )
+            })
             .collect();
 
         let images = Tensor::cat(images, 0).to_device(&self.device);
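The substance of the fix is the turbofish on `elem`: the conversion target is now stated explicitly as the backend's integer element type, rather than left to inference through a `TensorData::from` wrapper. A minimal sketch of the same conversion in isolation (the `label_to_target` helper is hypothetical, not part of the guide):

```rust
use burn::prelude::*;

// Hypothetical helper isolating the fixed conversion: `elem::<B::IntElem>()`
// converts the i64 label into the backend's integer element type explicitly,
// and the resulting one-element array converts directly into tensor data.
fn label_to_target<B: Backend>(label: u8, device: &B::Device) -> Tensor<B, 1, Int> {
    Tensor::<B, 1, Int>::from_data([(label as i64).elem::<B::IntElem>()], device)
}
```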


@@ -10,15 +10,12 @@ cost. Let's create a simple `infer` method in a new file `src/inference.rs` which
 load our trained model.
 
 ```rust , ignore
-# use burn::{
-#     config::Config,
-#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
-#     module::Module,
-#     record::{CompactRecorder, Recorder},
-#     tensor::backend::Backend,
-# };
-#
 # use crate::{data::MnistBatcher, training::TrainingConfig};
+# use burn::{
+#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+#     prelude::*,
+#     record::{CompactRecorder, Recorder},
+# };
 #
 pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
     let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
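The individually dropped imports are covered by the glob: `burn::prelude::*` re-exports the names this snippet still uses. A rough sketch of the mapping, assuming the current prelude contents (abbreviated; the prelude exports more than these three):

```rust
// Approximate expansion of the names this snippet takes from `burn::prelude::*`:
use burn::config::Config;           // was imported as `config::Config`
use burn::module::Module;           // was imported as `module::Module`
use burn::tensor::backend::Backend; // was imported as `tensor::backend::Backend`
```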


@@ -221,8 +221,8 @@ impl ModelConfig {
 At a glance, you can view the model configuration by printing the model instance:
 
 ```rust , ignore
-use crate::model::ModelConfig;
 use burn::backend::Wgpu;
+use guide::model::ModelConfig;
 
 fn main() {
     type MyBackend = Wgpu<f32, i32>;
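Switching the import from `crate::` to `guide::` aligns the snippet with the example, where the crate is named `guide` and is used as a library from the example binary. For context, the surrounding book snippet continues roughly like this (reconstructed for illustration, not part of the diff):

```rust
use burn::backend::Wgpu;
use guide::model::ModelConfig; // `guide` is the example crate's package name

fn main() {
    // Wgpu backend with f32 float elements and i32 int elements.
    type MyBackend = Wgpu<f32, i32>;

    let device = Default::default();
    let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);

    // `Module` implements `Display`, so printing the model lists its layers.
    println!("{}", model);
}
```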


@@ -39,7 +39,9 @@ impl<B: Backend> Model<B> {
         targets: Tensor<B, 1, Int>,
     ) -> ClassificationOutput<B> {
         let output = self.forward(images);
-        let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+        let loss = CrossEntropyLossConfig::new()
+            .init(&output.device())
+            .forward(output.clone(), targets.clone());
 
         ClassificationOutput::new(loss, output, targets)
     }
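The loss is now built through its config type, matching how other burn modules are constructed: `CrossEntropyLossConfig::new().init(&device)` yields the loss module, and options can be set on the config before `init`. A small self-contained sketch (the `cross_entropy` helper is hypothetical):

```rust
use burn::nn::loss::CrossEntropyLossConfig;
use burn::prelude::*;

// Hypothetical helper: build the loss module from its config, then apply it
// to logits of shape [batch_size, num_classes] and integer class targets.
fn cross_entropy<B: Backend>(
    logits: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
    CrossEntropyLossConfig::new()
        .init(&logits.device())
        .forward(logits, targets)
}
```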
@@ -60,28 +62,23 @@ Moving forward, we will proceed with the implementation of both the training and validation steps
 for our model.
 
 ```rust , ignore
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
+#     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
 #
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
@@ -89,8 +86,9 @@ for our model.
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }
@@ -147,28 +145,23 @@ Book.
 Let us move on to establishing the practical training configuration.
 
 ```rust , ignore
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
+#     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
 #
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
@@ -176,8 +169,9 @@ Let us move on to establishing the practical training configuration.
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }