From 560d77d1549fab2ff73526937b398fbe22f61d2c Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Tue, 18 Jun 2024 16:45:38 -0400
Subject: [PATCH] Doc: Improve module to_device/fork docs (#1901)

---
 crates/burn-core/src/module/base.rs | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/crates/burn-core/src/module/base.rs b/crates/burn-core/src/module/base.rs
index 6da8ea1e8..c34c155e8 100644
--- a/crates/burn-core/src/module/base.rs
+++ b/crates/burn-core/src/module/base.rs
@@ -97,17 +97,18 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
     ///
     /// # Notes
     ///
-    /// This is similar to [to_device](Module::to_device), but it ensures the module will
-    /// have its own autodiff graph.
+    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
+    /// new device will have its own autodiff graph.
     fn fork(self, device: &B::Device) -> Self;
 
     /// Move the module and all of its sub-modules to the given device.
     ///
     /// # Warnings
     ///
-    /// The device operations will be registered in the autodiff graph. Therefore, be sure to call
-    /// backward only one time even if you have the same module on multiple devices. If you want to
-    /// call backward multiple times, look into using [fork](Module::fork) instead.
+    /// The operation supports autodiff, and the transfer is registered in the autodiff graph
+    /// when it is active. However, this may not be what you want: the output module is an
+    /// intermediary module, meaning that you can't optimize it with gradient descent. To
+    /// optimize the output module on the target device, use [fork](Module::fork) instead.
     fn to_device(self, device: &B::Device) -> Self;
 
     /// Each tensor in the module tree will not require grad.
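
For readers comparing the two methods, the following is a minimal sketch of the behavioral difference, not part of the patch itself. It assumes burn built with the `wgpu` and `autodiff` features; the `Linear` module, sizes, and devices are illustrative placeholders.

use burn::backend::{Autodiff, Wgpu};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};

// Sketch only: assumes the `wgpu` and `autodiff` features of burn.
type B = Autodiff<Wgpu>;

fn main() {
    // Illustrative devices; in practice these would point at different hardware.
    let device_a = Default::default();
    let device_b = Default::default();

    let model: Linear<B> = LinearConfig::new(4, 4).init(&device_a);

    // `fork` moves the module and detaches it: the copy on `device_b` has its
    // own autodiff graph and can be optimized independently.
    let forked = model.clone().fork(&device_b);

    // `to_device` registers the transfer in the current autodiff graph: the
    // result is an intermediary module and can't be optimized directly.
    let moved = model.to_device(&device_b);

    let _ = (forked, moved);
}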