fix/ndarray: remove reversed axes check (#1058)

This commit is contained in:
AuruTus 2023-12-15 07:09:24 +08:00 committed by GitHub
parent 4608cd97d2
commit fc97a28f16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 9 deletions

View File

@ -72,7 +72,9 @@ where
.unwrap()
.into_shared();
NdArrayTensor { array }
// Transform column-major layout into row-major (standard) layout. (fix #1053)
let array = NdArrayTensor { array };
Self::reshape(array.clone(), array.shape())
}
fn to_slice_args<const D1: usize, const D2: usize>(
@ -496,3 +498,42 @@ fn arg<E: NdArrayElement, const D: usize>(
array: output.into_shared(),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn should_generate_row_major_layout_for_cat() {
let expected_shape: &[usize] = &[4, 6, 2];
let expected_strides: &[isize] = &[12, 2, 1];
let expected_array = NdArrayTensor::from_data(Data::<i32, 3>::from([
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
[[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
[[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
[[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
]));
// unsqueeze dim on the outermost axis
let array = NdArrayOps::reshape(
NdArrayTensor::from_data(Data::<i32, 2>::from([
[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
])),
Shape::from([4, 6, 1]),
);
let zeros = NdArrayTensor::<i32, 3>::from_data(Data::zeros([4, 6, 1]));
// make `ndarray` concatenates array on the outermost axis
let array = NdArrayOps::cat([array, zeros].to_vec(), 2);
assert!(array.array.is_standard_layout());
assert_eq!(array.array.shape(), expected_shape);
assert_eq!(array.array.strides(), expected_strides);
assert_eq!(
array.array.into_iter().collect::<Vec<_>>(),
expected_array.array.into_iter().collect::<Vec<_>>(),
);
}
}

View File

@ -63,14 +63,7 @@ macro_rules! reshape {
array $array:expr
) => {{
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let safe_into_shape =
$array.is_standard_layout() ||
(
$array.ndim() > 1 &&
$array.raw_view().reversed_axes().is_standard_layout()
);
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
true => $array
.into_shape(dim)
.expect("Safe to change shape without relayout")

View File

@ -90,4 +90,21 @@ mod tests {
let output: Tensor<TestBackend, 4> = TestTensor::stack(vec![tensor_1, tensor_2], 3);
}
#[test]
fn should_generate_row_major_layout() {
let data_expected = Data::from([
[1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0],
[7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0],
[13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0],
[19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0],
]);
let tensor = Tensor::<TestBackend, 1, Int>::arange(1..25).reshape([4, 6]);
let zeros: Tensor<TestBackend, 2, Int> = Tensor::zeros([4, 6]);
let intersperse =
Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]);
assert_eq!(data_expected, intersperse.into_data());
}
}