mirror of https://github.com/tracel-ai/burn.git
Add test
This commit is contained in:
parent
a1bd14c5ae
commit
0a1f107e45
|
@ -3712,14 +3712,6 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "refactor"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.4"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(ad_maxmin)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::Data;
|
||||
use burn_tensor::{Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_diff_max_dim() {
|
||||
|
@ -48,4 +48,21 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_dim_complex() {
|
||||
let device = Default::default();
|
||||
let a: Vec<f32> = vec![0.0, 0.0];
|
||||
let b = [0, 0];
|
||||
let b: Tensor<TestAutodiffBackend, 2, Int> =
|
||||
Tensor::from_data(Data::from(b.as_slice()), &device).reshape([2, 1]);
|
||||
let a = Tensor::from_data(Data::from(a.as_slice()), &device)
|
||||
.reshape([2, 1])
|
||||
.require_grad();
|
||||
|
||||
let loss = a.gather(1, b);
|
||||
let loss = loss.clone().max_dim(0) + loss; //No panic if this line is commented out
|
||||
let loss = loss.sum();
|
||||
let g = loss.backward();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue