From 60fdab4e17d3e420f20610ec75df3deccd8e1f69 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 5 Nov 2023 14:07:41 +0100
Subject: [PATCH] Detach all grads during backprop. (#1243)

* Detach all grads during backprop.

* Add an environment variable to select the backprop behavior.

* Update the comment.
---
 candle-core/src/backprop.rs | 25 +++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 1448a6f4..fc0c79a2 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
     }
 }
 
+thread_local! {
+    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
+        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
@@ -155,10 +166,16 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph). The only drawback would be if we wanted to support grad of grad but
-            // this is out of scope.
+            let grad = grads
+                .remove(node)
+                .expect("candle internal error - grad not populated");
+            // https://github.com/huggingface/candle/issues/1241
+            // Ideally, we would make these operations in place where possible to ensure that we
+            // do not have to allocate too often. Here we just call `.detach` to avoid computing
+            // the backprop graph of the backprop itself. This would be an issue for second order
+            // derivatives but these are out of scope at the moment.
+            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+            let grad = if do_not_detach { grad } else { grad.detach()? };
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {
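
Usage note: after this patch, gradients produced by `backward()` are detached
by default, so the backward pass no longer builds its own autograd graph.
Setting CANDLE_GRAD_DO_NOT_DETACH to any non-empty value other than "0" (e.g.
`CANDLE_GRAD_DO_NOT_DETACH=1 cargo test`) restores the previous behavior. The
sketch below is illustrative only, not part of the patch; it assumes candle's
public `Var`, `Tensor::backward`, and `GradStore::get` API as of this commit:

    use candle_core::{Device, Var};

    fn main() -> candle_core::Result<()> {
        // d(x^2)/dx = 2x, so the gradient at x = 3 should be 6.
        let x = Var::new(3f32, &Device::Cpu)?;
        let y = x.sqr()?;
        let grads = y.backward()?;
        let dx = grads.get(&x).expect("no gradient for x");
        // With the default behavior, dx was detached during backprop, so
        // reusing or accumulating it cannot grow the autograd graph. With
        // CANDLE_GRAD_DO_NOT_DETACH=1 the graph is kept instead, which the
        // patch comment notes would only matter for second-order derivatives.
        println!("dy/dx = {}", dx.to_scalar::<f32>()?);
        Ok(())
    }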