mirror of https://github.com/tracel-ai/burn.git
Fix target convert in batcher and align guide imports (#2215)
* Fix target convert in batcher
* Align hidden code in training and update loss to use config
* Align imports with example
* Remove unused import and fix guide->crate
This commit is contained in:
parent 3664c6ac69
commit 0e445a9680
@@ -79,10 +79,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
         let targets = items
             .iter()
-            .map(|item| Tensor::<B, 1, Int>::from_data(
-                TensorData::from([(item.label as i64).elem()]),
-                &self.device
-            ))
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    [(item.label as i64).elem::<B::IntElem>()],
+                    &self.device,
+                )
+            })
             .collect();
 
         let images = Tensor::cat(images, 0).to_device(&self.device);
 
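For readers following along: a minimal standalone sketch of the fixed conversion, assuming a plain `u8` label slice and a hypothetical `targets_from_labels` helper (neither is part of the guide). The key change is converting each label to the backend's integer element type with `elem::<B::IntElem>()` before building the rank-1 `Int` tensor.

```rust
use burn::prelude::*;

// Hypothetical helper illustrating the fixed pattern: convert each label
// to the backend's integer element type, then build a 1-D Int tensor on
// the given device (mirrors the batcher's closure above).
fn targets_from_labels<B: Backend>(labels: &[u8], device: &B::Device) -> Vec<Tensor<B, 1, Int>> {
    labels
        .iter()
        .map(|label| Tensor::<B, 1, Int>::from_data([(*label as i64).elem::<B::IntElem>()], device))
        .collect()
}
```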
@@ -10,15 +10,12 @@ cost. Let's create a simple `infer` method in a new file `src/inference.rs` whic
 load our trained model.
 
 ```rust , ignore
-# use burn::{
-#     config::Config,
-#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
-#     module::Module,
-#     record::{CompactRecorder, Recorder},
-#     tensor::backend::Backend,
-# };
-#
 # use crate::{data::MnistBatcher, training::TrainingConfig};
+# use burn::{
+#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+#     prelude::*,
+#     record::{CompactRecorder, Recorder},
+# };
 #
 pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
     let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
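As the hunk above suggests, `burn::prelude::*` now covers items that previously needed individual imports (`Config`, `Module`, `tensor::backend::Backend`). A minimal sketch, assuming the prelude re-exports `Backend`, `Tensor`, and `Int` (the function itself is hypothetical):

```rust
use burn::prelude::*;

// Hypothetical function; it only demonstrates that `Backend`, `Tensor`,
// and `Int` resolve through the prelude without dedicated `use` lines.
fn passthrough<B: Backend>(targets: Tensor<B, 1, Int>) -> Tensor<B, 1, Int> {
    targets
}
```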
@@ -221,8 +221,8 @@ impl ModelConfig {
 At a glance, you can view the model configuration by printing the model instance:
 
 ```rust , ignore
+use crate::model::ModelConfig;
 use burn::backend::Wgpu;
-use guide::model::ModelConfig;
 
 fn main() {
     type MyBackend = Wgpu<f32, i32>;
@@ -39,7 +39,9 @@ impl<B: Backend> Model<B> {
         targets: Tensor<B, 1, Int>,
     ) -> ClassificationOutput<B> {
         let output = self.forward(images);
-        let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+        let loss = CrossEntropyLossConfig::new()
+            .init(&output.device())
+            .forward(output.clone(), targets.clone());
 
         ClassificationOutput::new(loss, output, targets)
     }
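The same change, extracted as a free function for clarity; the `compute_loss` name is an assumption, but the config-init-forward chain is exactly what the hunk introduces: build `CrossEntropyLossConfig`, initialize the loss module on the logits' device, then forward the logits and the integer targets.

```rust
use burn::{nn::loss::CrossEntropyLossConfig, prelude::*};

// Sketch of the config-first construction (function name is an assumption):
// the config builds a CrossEntropyLoss on the logits' device, and `forward`
// consumes the logits and the integer class targets.
fn compute_loss<B: Backend>(output: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
    CrossEntropyLossConfig::new()
        .init(&output.device())
        .forward(output, targets)
}
```

Routing construction through a config mirrors how the guide builds its other modules and keeps the loss hyperparameters declarative.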
@@ -60,28 +62,23 @@ Moving forward, we will proceed with the implementation of both the training and
 for our model.
 
 ```rust , ignore
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
+#     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
 #
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
-#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
@@ -89,8 +86,9 @@ for our model.
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }
@@ -147,28 +145,23 @@ Book.
 Let us move on to establishing the practical training configuration.
 
 ```rust , ignore
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
+#     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
 #
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
-#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
@@ -176,8 +169,9 @@ Let us move on to establishing the practical training configuration.
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }