mirror of https://github.com/tracel-ai/burn.git
fix(book): add missing device parameter to mode.init() (#1302)
This commit is contained in:
parent
938a9d00b3
commit
a68b494531
|
@ -139,10 +139,10 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
|
|||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.devices(vec![device.clone()])
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(
|
||||
config.model.init::<B>(),
|
||||
config.model.init::<B>(&device),
|
||||
config.optimizer.init(),
|
||||
config.learning_rate,
|
||||
);
|
||||
|
|
Loading…
Reference in New Issue