mirror of https://github.com/tracel-ai/burn.git
Add 0-dim tensor checks for creation ops and validate TensorData shape w/ num values (#2137)
This commit is contained in:
parent
16239db252
commit
d2699022df
|
@ -58,7 +58,9 @@ where
|
|||
|
||||
/// Create an empty tensor of the given shape.
|
||||
pub fn empty<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
|
||||
Self::new(K::empty(shape.into(), device))
|
||||
let shape = shape.into();
|
||||
check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
|
||||
Self::new(K::empty(shape, device))
|
||||
}
|
||||
|
||||
/// Returns the dimensions of the current tensor.
|
||||
|
@ -717,7 +719,10 @@ where
|
|||
T: Into<TensorData>,
|
||||
{
|
||||
let data = data.into();
|
||||
check!(TensorCheck::from_data::<D>(data.shape.as_slice()));
|
||||
check!(TensorCheck::creation_ops::<D>(
|
||||
"From Data",
|
||||
data.shape.as_slice()
|
||||
));
|
||||
Self::new(K::from_data(data, device))
|
||||
}
|
||||
|
||||
|
|
|
@ -80,12 +80,20 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn from_data<const D: usize>(dims: &[usize]) -> Self {
|
||||
pub(crate) fn creation_ops<const D: usize>(ops: &str, dims: &[usize]) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
if D == 0 {
|
||||
check = check.register(
|
||||
ops,
|
||||
TensorError::new("Tried to create a 0-dim tensor, which is invalid.")
|
||||
.details(format!("Tensor rank: '{D}', given dimensions: '{dims:?}'.")),
|
||||
);
|
||||
}
|
||||
|
||||
if dims.len() != D {
|
||||
check = check.register(
|
||||
"From Data",
|
||||
ops,
|
||||
TensorError::new("Given dimensions differ from the tensor rank.")
|
||||
.details(format!("Tensor rank: '{D}', given dimensions: '{dims:?}'.")),
|
||||
);
|
||||
|
|
|
@ -101,12 +101,16 @@ where
|
|||
|
||||
/// Create a tensor of the given shape where each element is zero.
|
||||
pub fn zeros<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
|
||||
Self::new(K::zeros(shape.into(), device))
|
||||
let shape = shape.into();
|
||||
check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
|
||||
Self::new(K::zeros(shape, device))
|
||||
}
|
||||
|
||||
/// Create a tensor of the given shape where each element is one.
|
||||
pub fn ones<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
|
||||
Self::new(K::ones(shape.into(), device))
|
||||
let shape = shape.into();
|
||||
check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
|
||||
Self::new(K::ones(shape, device))
|
||||
}
|
||||
|
||||
/// Create a tensor of the given shape where each element is equal to the provided value.
|
||||
|
@ -115,7 +119,9 @@ where
|
|||
fill_value: E,
|
||||
device: &B::Device,
|
||||
) -> Self {
|
||||
Self::new(K::full(shape.into(), fill_value, device))
|
||||
let shape = shape.into();
|
||||
check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
|
||||
Self::new(K::full(shape, fill_value, device))
|
||||
}
|
||||
|
||||
/// Aggregate all elements in the tensor with the mean operation.
|
||||
|
|
|
@ -63,8 +63,18 @@ impl TensorData {
|
|||
// Ensure `E` satisfies the `Pod` trait requirements
|
||||
assert_eq!(core::mem::size_of::<E>() % core::mem::size_of::<u8>(), 0);
|
||||
|
||||
// Ensure shape is valid
|
||||
let shape = shape.into();
|
||||
let shape_numel = Self::numel(&shape);
|
||||
let numel = value.len();
|
||||
assert_eq!(
|
||||
shape_numel, numel,
|
||||
"Shape {:?} is invalid for input of size {:?}",
|
||||
shape, numel,
|
||||
);
|
||||
|
||||
let factor = core::mem::size_of::<E>() / core::mem::size_of::<u8>();
|
||||
let len = value.len() * factor;
|
||||
let len = numel * factor;
|
||||
let capacity = value.capacity() * factor;
|
||||
let ptr = value.as_mut_ptr();
|
||||
|
||||
|
@ -74,7 +84,7 @@ impl TensorData {
|
|||
|
||||
Self {
|
||||
bytes,
|
||||
shape: shape.into(),
|
||||
shape,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue