fix(book): add missing device parameter to mode.init() (#1302)

This commit is contained in:
Jakub 2024-02-13 15:34:03 +01:00 committed by GitHub
parent 938a9d00b3
commit a68b494531
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 2 deletions

View File

@ -139,10 +139,10 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
.metric_train_numeric(LossMetric::new()) .metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new()) .with_file_checkpointer(CompactRecorder::new())
.devices(vec![device]) .devices(vec![device.clone()])
.num_epochs(config.num_epochs) .num_epochs(config.num_epochs)
.build( .build(
config.model.init::<B>(), config.model.init::<B>(&device),
config.optimizer.init(), config.optimizer.init(),
config.learning_rate, config.learning_rate,
); );