mirror of https://github.com/tracel-ai/burn.git
fix(book): add missing device parameter to mode.init() (#1302)
This commit is contained in:
parent
938a9d00b3
commit
a68b494531
|
@ -139,10 +139,10 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
|
||||||
.metric_train_numeric(LossMetric::new())
|
.metric_train_numeric(LossMetric::new())
|
||||||
.metric_valid_numeric(LossMetric::new())
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.with_file_checkpointer(CompactRecorder::new())
|
.with_file_checkpointer(CompactRecorder::new())
|
||||||
.devices(vec![device])
|
.devices(vec![device.clone()])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
.build(
|
.build(
|
||||||
config.model.init::<B>(),
|
config.model.init::<B>(&device),
|
||||||
config.optimizer.init(),
|
config.optimizer.init(),
|
||||||
config.learning_rate,
|
config.learning_rate,
|
||||||
);
|
);
|
||||||
|
|
Loading…
Reference in New Issue