Add 0-dim tensor checks for creation ops and validate TensorData shape w/ num values (#2137)

This commit is contained in:
Guillaume Lagrange 2024-08-15 09:54:22 -04:00 committed by GitHub
parent 16239db252
commit d2699022df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 38 additions and 9 deletions

View File

@ -58,7 +58,9 @@ where
/// Create an empty tensor of the given shape.
pub fn empty<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
Self::new(K::empty(shape.into(), device))
let shape = shape.into();
check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
Self::new(K::empty(shape, device))
}
/// Returns the dimensions of the current tensor.
@ -717,7 +719,10 @@ where
T: Into<TensorData>,
{
let data = data.into();
check!(TensorCheck::from_data::<D>(data.shape.as_slice()));
check!(TensorCheck::creation_ops::<D>(
"From Data",
data.shape.as_slice()
));
Self::new(K::from_data(data, device))
}

View File

@ -80,12 +80,20 @@ impl TensorCheck {
check
}
pub(crate) fn from_data<const D: usize>(dims: &[usize]) -> Self {
pub(crate) fn creation_ops<const D: usize>(ops: &str, dims: &[usize]) -> Self {
let mut check = Self::Ok;
if D == 0 {
check = check.register(
ops,
TensorError::new("Tried to create a 0-dim tensor, which is invalid.")
.details(format!("Tensor rank: '{D}', given dimensions: '{dims:?}'.")),
);
}
if dims.len() != D {
check = check.register(
"From Data",
ops,
TensorError::new("Given dimensions differ from the tensor rank.")
.details(format!("Tensor rank: '{D}', given dimensions: '{dims:?}'.")),
);

View File

@ -101,12 +101,16 @@ where
/// Create a tensor of the given shape where each element is zero.
pub fn zeros<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
Self::new(K::zeros(shape.into(), device))
let shape = shape.into();
check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
Self::new(K::zeros(shape, device))
}
/// Create a tensor of the given shape where each element is one.
pub fn ones<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
Self::new(K::ones(shape.into(), device))
let shape = shape.into();
check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
Self::new(K::ones(shape, device))
}
/// Create a tensor of the given shape where each element is equal to the provided value.
@ -115,7 +119,9 @@ where
fill_value: E,
device: &B::Device,
) -> Self {
Self::new(K::full(shape.into(), fill_value, device))
let shape = shape.into();
check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
Self::new(K::full(shape, fill_value, device))
}
/// Aggregate all elements in the tensor with the mean operation.

View File

@ -63,8 +63,18 @@ impl TensorData {
// Ensure `E` satisfies the `Pod` trait requirements
assert_eq!(core::mem::size_of::<E>() % core::mem::size_of::<u8>(), 0);
// Ensure shape is valid
let shape = shape.into();
let shape_numel = Self::numel(&shape);
let numel = value.len();
assert_eq!(
shape_numel, numel,
"Shape {:?} is invalid for input of size {:?}",
shape, numel,
);
let factor = core::mem::size_of::<E>() / core::mem::size_of::<u8>();
let len = value.len() * factor;
let len = numel * factor;
let capacity = value.capacity() * factor;
let ptr = value.as_mut_ptr();
@ -74,7 +84,7 @@ impl TensorData {
Self {
bytes,
shape: shape.into(),
shape,
dtype,
}
}