From 0a1f107e45e7837100070f1840e4e8a6131df539 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 25 Apr 2024 10:46:14 -0400 Subject: [PATCH] Add test --- Cargo.lock | 8 -------- crates/burn-autodiff/src/tests/maxmin.rs | 19 ++++++++++++++++++- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0b4512a2d..f025a3c18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3712,14 +3712,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "refactor" -version = "0.14.0" -dependencies = [ - "burn", - "serde", -] - [[package]] name = "regex" version = "1.10.4" diff --git a/crates/burn-autodiff/src/tests/maxmin.rs b/crates/burn-autodiff/src/tests/maxmin.rs index 1e371eecc..e51f74223 100644 --- a/crates/burn-autodiff/src/tests/maxmin.rs +++ b/crates/burn-autodiff/src/tests/maxmin.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(ad_maxmin)] mod tests { use super::*; - use burn_tensor::Data; + use burn_tensor::{Data, Int, Tensor}; #[test] fn should_diff_max_dim() { @@ -48,4 +48,21 @@ mod tests { .to_data() .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5); } + + #[test] + fn test_max_dim_complex() { + let device = Default::default(); + let a: Vec = vec![0.0, 0.0]; + let b = [0, 0]; + let b: Tensor = + Tensor::from_data(Data::from(b.as_slice()), &device).reshape([2, 1]); + let a = Tensor::from_data(Data::from(a.as_slice()), &device) + .reshape([2, 1]) + .require_grad(); + + let loss = a.gather(1, b); + let loss = loss.clone().max_dim(0) + loss; //No panic if this line is commented out + let loss = loss.sum(); + let g = loss.backward(); + } }