candle/candle-core/tests/layout_tests.rs

153 lines
5.1 KiB
Rust

use candle::{test_device, Device, IndexOp, Result, Tensor};
use candle_core as candle;
/// Verifies the element order produced when transposed and sliced views of a
/// `(2, 3, 4)` tensor are materialised through `contiguous()` or `flatten_all()`.
///
/// The tensor holds the values `0..24`, so each expected value identifies which
/// source element ended up at a given position of the result.
fn contiguous(device: &Device) -> Result<()> {
    let base = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;

    // Baseline: the untouched tensor reads back in row-major order.
    let as_is = base.to_vec3::<u32>()?;
    assert_eq!(
        as_is,
        &[
            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
            [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
        ]
    );

    // `t()` exchanges the two trailing dimensions, giving a (2, 4, 3) result.
    let last_two_swapped = base.t()?.contiguous()?;
    assert_eq!(
        last_two_swapped.to_vec3::<u32>()?,
        &[
            [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]],
            [[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]]
        ]
    );

    // Exchanging the two leading dimensions keeps each row of four intact.
    let first_two_swapped = base.transpose(0, 1)?;
    assert_eq!(
        first_two_swapped.contiguous()?.to_vec3::<u32>()?,
        &[
            [[0, 1, 2, 3], [12, 13, 14, 15]],
            [[4, 5, 6, 7], [16, 17, 18, 19]],
            [[8, 9, 10, 11], [20, 21, 22, 23]]
        ]
    );

    // Flattening the same transposed view must follow its logical order,
    // without an explicit `contiguous()` call in between.
    let flattened = base.transpose(0, 1)?.flatten_all()?;
    assert_eq!(
        flattened.to_vec1::<u32>()?,
        &[0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23]
    );

    // Drop the first outer entry, then transpose: the result is (3, 1, 4).
    let tail = base.i(1..)?;
    let tail_swapped = tail.transpose(0, 1)?.contiguous()?;
    assert_eq!(
        tail_swapped.to_vec3::<u32>()?,
        &[[[12, 13, 14, 15]], [[16, 17, 18, 19]], [[20, 21, 22, 23]]]
    );

    // Exchanging the outermost and innermost dimensions gives (4, 3, 2).
    let outer_inner_swapped = base.transpose(0, 2)?.contiguous()?;
    assert_eq!(
        outer_inner_swapped.to_vec3::<u32>()?,
        &[
            [[0, 12], [4, 16], [8, 20]],
            [[1, 13], [5, 17], [9, 21]],
            [[2, 14], [6, 18], [10, 22]],
            [[3, 15], [7, 19], [11, 23]]
        ]
    );

    Ok(())
}
// Hands the `contiguous` helper to the `test_device!` macro (imported from `candle`)
// together with three test names.
// NOTE(review): presumably this expands to one `#[test]` per backend (CPU / GPU / Metal),
// as the names suggest — confirm against the macro definition.
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
/// Exercises `Tensor::strided_blocks` on CPU tensors with a range of layouts and
/// checks whether each is reported as one block or as several equally sized blocks,
/// along with the reported offsets and lengths.
#[test]
fn strided_blocks() -> Result<()> {
    use candle::Device::Cpu;

    /// Returns `(start_offset, len)` when `t` is reported as a single block;
    /// panics for any other block structure.
    fn single_block(t: &Tensor) -> (usize, usize) {
        match t.strided_blocks() {
            candle::StridedBlocks::SingleBlock { start_offset, len } => (start_offset, len),
            candle::StridedBlocks::MultipleBlocks { .. } => {
                panic!("unexpected block structure")
            }
        }
    }

    /// Returns `(block_len, block start indices)` when `t` is reported as several
    /// blocks; panics for any other block structure.
    fn multiple_blocks(t: &Tensor) -> (usize, Vec<usize>) {
        match t.strided_blocks() {
            candle::StridedBlocks::SingleBlock { .. } => {
                panic!("unexpected block structure")
            }
            candle::StridedBlocks::MultipleBlocks {
                block_len,
                block_start_index,
            } => (block_len, block_start_index.collect::<Vec<_>>()),
        }
    }

    // Every scenario but one starts from the same (2, 3, 4) tensor holding 0..24.
    let fresh = || -> Result<Tensor> { Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4)) };

    // An untouched tensor is one block covering all 24 elements.
    let whole = fresh()?;
    let (start_offset, len) = single_block(&whole);
    assert_eq!(start_offset, 0);
    assert_eq!(len, 24);

    // Skipping the first two elements before reshaping only moves the start offset.
    let shifted = Tensor::arange(0u32, 26u32, &Cpu)?
        .i(2..)?
        .reshape((2, 3, 4))?;
    let (start_offset, len) = single_block(&shifted);
    assert_eq!(start_offset, 2);
    assert_eq!(len, 24);

    // Selecting the second outer entry is still one block, starting at element 12.
    let second_half = fresh()?.i(1)?;
    let (start_offset, len) = single_block(&second_half);
    assert_eq!(start_offset, 12);
    assert_eq!(len, 12);

    // Selecting index 1 along the middle dimension and then calling `contiguous()`
    // gives a single block starting at offset 0.
    let middle_packed = fresh()?.i((.., 1))?.contiguous()?;
    let (start_offset, len) = single_block(&middle_packed);
    assert_eq!(start_offset, 0);
    assert_eq!(len, 8);
    assert_eq!(
        middle_packed.to_vec2::<u32>()?,
        &[[4, 5, 6, 7], [16, 17, 18, 19]]
    );

    // The same selection without `contiguous()` is two separate blocks of four.
    let middle_view = fresh()?.i((.., 1))?;
    let (block_len, starts) = multiple_blocks(&middle_view);
    assert_eq!(block_len, 4);
    assert_eq!(starts, &[4, 16]);

    // Swapping the last two dimensions breaks the layout into single-element blocks.
    let last_two_swapped = fresh()?.t()?;
    let (block_len, starts) = multiple_blocks(&last_two_swapped);
    assert_eq!(block_len, 1);
    assert_eq!(
        starts,
        &[
            0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15,
            19, 23
        ]
    );

    // Swapping the first two dimensions keeps blocks of four, visited in a new order.
    let first_two_swapped = fresh()?.transpose(0, 1)?;
    let (block_len, starts) = multiple_blocks(&first_two_swapped);
    assert_eq!(block_len, 4);
    assert_eq!(starts, &[0, 12, 4, 16, 8, 20]);

    Ok(())
}