Detach all grads during backprop. (#1243)
* Detach all grads during backprop. * Add an environment variable to select the backprop behavior. * Update the comment.
This commit is contained in:
parent
928a9d906e
commit
60fdab4e17
|
@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
|
|||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
|
||||
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||
|
@ -155,10 +166,16 @@ impl Tensor {
|
|||
if node.is_variable() {
|
||||
continue;
|
||||
}
|
||||
let grad = grads.remove(node).unwrap();
|
||||
// TODO: We should perform all these operations in place (or at least not track the
|
||||
// whole graph). The only drawback would be if we wanted to support grad of grad but
|
||||
// this is out of scope.
|
||||
let grad = grads
|
||||
.remove(node)
|
||||
.expect("candle internal error - grad not populated");
|
||||
// https://github.com/huggingface/candle/issues/1241
|
||||
// Ideally, we would make these operations in place where possible to ensure that we
|
||||
// do not have to allocate too often. Here we just call `.detach` to avoid computing
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach()? };
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
|
|
Loading…
Reference in New Issue