This commit is contained in:
nathaniel 2024-04-25 10:46:14 -04:00
parent a1bd14c5ae
commit 0a1f107e45
2 changed files with 18 additions and 9 deletions

8
Cargo.lock generated
View File

@ -3712,14 +3712,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "refactor"
version = "0.14.0"
dependencies = [
"burn",
"serde",
]
[[package]]
name = "regex"
version = "1.10.4"

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(ad_maxmin)]
mod tests {
use super::*;
use burn_tensor::Data;
use burn_tensor::{Data, Int, Tensor};
#[test]
fn should_diff_max_dim() {
@ -48,4 +48,21 @@ mod tests {
.to_data()
.assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5);
}
#[test]
fn test_max_dim_complex() {
let device = Default::default();
let a: Vec<f32> = vec![0.0, 0.0];
let b = [0, 0];
let b: Tensor<TestAutodiffBackend, 2, Int> =
Tensor::from_data(Data::from(b.as_slice()), &device).reshape([2, 1]);
let a = Tensor::from_data(Data::from(a.as_slice()), &device)
.reshape([2, 1])
.require_grad();
let loss = a.gather(1, b);
let loss = loss.clone().max_dim(0) + loss; //No panic if this line is commented out
let loss = loss.sum();
let g = loss.backward();
}
}