mirror of https://github.com/tracel-ai/burn.git
Add test
This commit is contained in:
parent
a1bd14c5ae
commit
0a1f107e45
|
@ -3712,14 +3712,6 @@ dependencies = [
|
||||||
"thiserror",
|
"thiserror",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "refactor"
|
|
||||||
version = "0.14.0"
|
|
||||||
dependencies = [
|
|
||||||
"burn",
|
|
||||||
"serde",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "1.10.4"
|
version = "1.10.4"
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#[burn_tensor_testgen::testgen(ad_maxmin)]
|
#[burn_tensor_testgen::testgen(ad_maxmin)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use burn_tensor::Data;
|
use burn_tensor::{Data, Int, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_max_dim() {
|
fn should_diff_max_dim() {
|
||||||
|
@ -48,4 +48,21 @@ mod tests {
|
||||||
.to_data()
|
.to_data()
|
||||||
.assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5);
|
.assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_dim_complex() {
|
||||||
|
let device = Default::default();
|
||||||
|
let a: Vec<f32> = vec![0.0, 0.0];
|
||||||
|
let b = [0, 0];
|
||||||
|
let b: Tensor<TestAutodiffBackend, 2, Int> =
|
||||||
|
Tensor::from_data(Data::from(b.as_slice()), &device).reshape([2, 1]);
|
||||||
|
let a = Tensor::from_data(Data::from(a.as_slice()), &device)
|
||||||
|
.reshape([2, 1])
|
||||||
|
.require_grad();
|
||||||
|
|
||||||
|
let loss = a.gather(1, b);
|
||||||
|
let loss = loss.clone().max_dim(0) + loss; //No panic if this line is commented out
|
||||||
|
let loss = loss.sum();
|
||||||
|
let g = loss.backward();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue