Fix missing device in custom training loop book example (#1606)

This commit is contained in:
Guillaume Lagrange 2024-04-12 10:34:04 -04:00 committed by GitHub
parent 23210f05f2
commit 06ce2b02d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -72,7 +72,8 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
// Implement our training loop.
for (iteration, batch) in dataloader_train.iter().enumerate() {
let output = model.forward(batch.images);
let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone());
let loss = CrossEntropyLoss::new(None, &output.device())
.forward(output.clone(), batch.targets.clone());
let accuracy = accuracy(output, batch.targets);
println!(
@ -97,7 +98,8 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
// Implement our validation loop.
for (iteration, batch) in dataloader_test.iter().enumerate() {
let output = model_valid.forward(batch.images);
let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone());
let loss = CrossEntropyLoss::new(None, &output.device())
.forward(output.clone(), batch.targets.clone());
let accuracy = accuracy(output, batch.targets);
println!(