mirror of https://github.com/tracel-ai/burn.git
Fix missing device in custom training loop book example (#1606)
This commit is contained in:
parent
23210f05f2
commit
06ce2b02d6
|
@ -72,7 +72,8 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
|
|||
// Implement our training loop.
|
||||
for (iteration, batch) in dataloader_train.iter().enumerate() {
|
||||
let output = model.forward(batch.images);
|
||||
let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone());
|
||||
let loss = CrossEntropyLoss::new(None, &output.device())
|
||||
.forward(output.clone(), batch.targets.clone());
|
||||
let accuracy = accuracy(output, batch.targets);
|
||||
|
||||
println!(
|
||||
|
@ -97,7 +98,8 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
|
|||
// Implement our validation loop.
|
||||
for (iteration, batch) in dataloader_test.iter().enumerate() {
|
||||
let output = model_valid.forward(batch.images);
|
||||
let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone());
|
||||
let loss = CrossEntropyLoss::new(None, &output.device())
|
||||
.forward(output.clone(), batch.targets.clone());
|
||||
let accuracy = accuracy(output, batch.targets);
|
||||
|
||||
println!(
|
||||
|
|
Loading…
Reference in New Issue