mirror of https://github.com/tracel-ai/burn.git
fix/ndarray: remove reversed axes check (#1058)
This commit is contained in:
parent
4608cd97d2
commit
fc97a28f16
|
@ -72,7 +72,9 @@ where
|
|||
.unwrap()
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
// Transform column-major layout into row-major (standard) layout. (fix #1053)
|
||||
let array = NdArrayTensor { array };
|
||||
Self::reshape(array.clone(), array.shape())
|
||||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
|
@ -496,3 +498,42 @@ fn arg<E: NdArrayElement, const D: usize>(
|
|||
array: output.into_shared(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_generate_row_major_layout_for_cat() {
|
||||
let expected_shape: &[usize] = &[4, 6, 2];
|
||||
let expected_strides: &[isize] = &[12, 2, 1];
|
||||
let expected_array = NdArrayTensor::from_data(Data::<i32, 3>::from([
|
||||
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
|
||||
[[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
|
||||
[[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
|
||||
[[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
|
||||
]));
|
||||
|
||||
// unsqueeze dim on the outermost axis
|
||||
let array = NdArrayOps::reshape(
|
||||
NdArrayTensor::from_data(Data::<i32, 2>::from([
|
||||
[1, 2, 3, 4, 5, 6],
|
||||
[7, 8, 9, 10, 11, 12],
|
||||
[13, 14, 15, 16, 17, 18],
|
||||
[19, 20, 21, 22, 23, 24],
|
||||
])),
|
||||
Shape::from([4, 6, 1]),
|
||||
);
|
||||
let zeros = NdArrayTensor::<i32, 3>::from_data(Data::zeros([4, 6, 1]));
|
||||
// make `ndarray` concatenates array on the outermost axis
|
||||
let array = NdArrayOps::cat([array, zeros].to_vec(), 2);
|
||||
|
||||
assert!(array.array.is_standard_layout());
|
||||
assert_eq!(array.array.shape(), expected_shape);
|
||||
assert_eq!(array.array.strides(), expected_strides);
|
||||
assert_eq!(
|
||||
array.array.into_iter().collect::<Vec<_>>(),
|
||||
expected_array.array.into_iter().collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,14 +63,7 @@ macro_rules! reshape {
|
|||
array $array:expr
|
||||
) => {{
|
||||
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
|
||||
let safe_into_shape =
|
||||
$array.is_standard_layout() ||
|
||||
(
|
||||
$array.ndim() > 1 &&
|
||||
$array.raw_view().reversed_axes().is_standard_layout()
|
||||
);
|
||||
|
||||
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
|
||||
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
|
||||
true => $array
|
||||
.into_shape(dim)
|
||||
.expect("Safe to change shape without relayout")
|
||||
|
|
|
@ -90,4 +90,21 @@ mod tests {
|
|||
|
||||
let output: Tensor<TestBackend, 4> = TestTensor::stack(vec![tensor_1, tensor_2], 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_generate_row_major_layout() {
|
||||
let data_expected = Data::from([
|
||||
[1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0],
|
||||
[7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0],
|
||||
[13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0],
|
||||
[19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0],
|
||||
]);
|
||||
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(1..25).reshape([4, 6]);
|
||||
let zeros: Tensor<TestBackend, 2, Int> = Tensor::zeros([4, 6]);
|
||||
let intersperse =
|
||||
Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]);
|
||||
|
||||
assert_eq!(data_expected, intersperse.into_data());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue