mirror of https://github.com/tracel-ai/burn.git
Fix target convert in batcher and align guide imports (#2215)
* Fix target convert in batcher * Align hidden code in training and update loss to use config * Align imports with example * Remove unused import and fix guide->crate
This commit is contained in:
parent
a88c69af4a
commit
7baa33bdaa
|
@ -79,10 +79,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
|
|||
|
||||
let targets = items
|
||||
.iter()
|
||||
.map(|item| Tensor::<B, 1, Int>::from_data(
|
||||
TensorData::from([(item.label as i64).elem()]),
|
||||
&self.device
|
||||
))
|
||||
.map(|item| {
|
||||
Tensor::<B, 1, Int>::from_data(
|
||||
[(item.label as i64).elem::<B::IntElem>()],
|
||||
&self.device,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let images = Tensor::cat(images, 0).to_device(&self.device);
|
||||
|
|
|
@ -10,15 +10,12 @@ cost. Let's create a simple `infer` method in a new file `src/inference.rs` whic
|
|||
load our trained model.
|
||||
|
||||
```rust , ignore
|
||||
# use burn::{
|
||||
# config::Config,
|
||||
# data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
||||
# module::Module,
|
||||
# record::{CompactRecorder, Recorder},
|
||||
# tensor::backend::Backend,
|
||||
# };
|
||||
#
|
||||
# use crate::{data::MnistBatcher, training::TrainingConfig};
|
||||
# use burn::{
|
||||
# data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
||||
# prelude::*,
|
||||
# record::{CompactRecorder, Recorder},
|
||||
# };
|
||||
#
|
||||
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
|
||||
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
|
||||
|
|
|
@ -221,8 +221,8 @@ impl ModelConfig {
|
|||
At a glance, you can view the model configuration by printing the model instance:
|
||||
|
||||
```rust , ignore
|
||||
use crate::model::ModelConfig;
|
||||
use burn::backend::Wgpu;
|
||||
use guide::model::ModelConfig;
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
|
|
|
@ -39,7 +39,9 @@ impl<B: Backend> Model<B> {
|
|||
targets: Tensor<B, 1, Int>,
|
||||
) -> ClassificationOutput<B> {
|
||||
let output = self.forward(images);
|
||||
let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
|
||||
let loss = CrossEntropyLossConfig::new()
|
||||
.init(&output.device())
|
||||
.forward(output.clone(), targets.clone());
|
||||
|
||||
ClassificationOutput::new(loss, output, targets)
|
||||
}
|
||||
|
@ -60,28 +62,23 @@ Moving forward, we will proceed with the implementation of both the training and
|
|||
for our model.
|
||||
|
||||
```rust , ignore
|
||||
# use crate::{
|
||||
# data::{MnistBatch, MnistBatcher},
|
||||
# model::{Model, ModelConfig},
|
||||
# };
|
||||
# use burn::{
|
||||
# config::Config,
|
||||
# data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
# module::Module,
|
||||
# nn::loss::CrossEntropyLoss,
|
||||
# nn::loss::CrossEntropyLossConfig,
|
||||
# optim::AdamConfig,
|
||||
# prelude::*,
|
||||
# record::CompactRecorder,
|
||||
# tensor::{
|
||||
# backend::{AutodiffBackend, Backend},
|
||||
# Int, Tensor,
|
||||
# },
|
||||
# tensor::backend::AutodiffBackend,
|
||||
# train::{
|
||||
# metric::{AccuracyMetric, LossMetric},
|
||||
# ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
# },
|
||||
# };
|
||||
#
|
||||
# use crate::{
|
||||
# data::{MnistBatch, MnistBatcher},
|
||||
# model::{Model, ModelConfig},
|
||||
# };
|
||||
#
|
||||
# impl<B: Backend> Model<B> {
|
||||
# pub fn forward_classification(
|
||||
# &self,
|
||||
|
@ -89,8 +86,9 @@ for our model.
|
|||
# targets: Tensor<B, 1, Int>,
|
||||
# ) -> ClassificationOutput<B> {
|
||||
# let output = self.forward(images);
|
||||
# let loss =
|
||||
# CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
|
||||
# let loss = CrossEntropyLossConfig::new()
|
||||
# .init(&output.device())
|
||||
# .forward(output.clone(), targets.clone());
|
||||
#
|
||||
# ClassificationOutput::new(loss, output, targets)
|
||||
# }
|
||||
|
@ -147,28 +145,23 @@ Book.
|
|||
Let us move on to establishing the practical training configuration.
|
||||
|
||||
```rust , ignore
|
||||
# use crate::{
|
||||
# data::{MnistBatch, MnistBatcher},
|
||||
# model::{Model, ModelConfig},
|
||||
# };
|
||||
# use burn::{
|
||||
# config::Config,
|
||||
# data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
# module::Module,
|
||||
# nn::loss::CrossEntropyLoss,
|
||||
# nn::loss::CrossEntropyLossConfig,
|
||||
# optim::AdamConfig,
|
||||
# prelude::*,
|
||||
# record::CompactRecorder,
|
||||
# tensor::{
|
||||
# backend::{AutodiffBackend, Backend},
|
||||
# Int, Tensor,
|
||||
# },
|
||||
# tensor::backend::AutodiffBackend,
|
||||
# train::{
|
||||
# metric::{AccuracyMetric, LossMetric},
|
||||
# ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
# },
|
||||
# };
|
||||
#
|
||||
# use crate::{
|
||||
# data::{MnistBatch, MnistBatcher},
|
||||
# model::{Model, ModelConfig},
|
||||
# };
|
||||
#
|
||||
# impl<B: Backend> Model<B> {
|
||||
# pub fn forward_classification(
|
||||
# &self,
|
||||
|
@ -176,8 +169,9 @@ Let us move on to establishing the practical training configuration.
|
|||
# targets: Tensor<B, 1, Int>,
|
||||
# ) -> ClassificationOutput<B> {
|
||||
# let output = self.forward(images);
|
||||
# let loss =
|
||||
# CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
|
||||
# let loss = CrossEntropyLossConfig::new()
|
||||
# .init(&output.device())
|
||||
# .forward(output.clone(), targets.clone());
|
||||
#
|
||||
# ClassificationOutput::new(loss, output, targets)
|
||||
# }
|
||||
|
|
Loading…
Reference in New Issue