Shape with holes (#770)

* Shape with holes.

* rustfmt.
This commit is contained in:
Laurent Mazare 2023-09-08 08:38:13 +01:00 committed by GitHub
parent cfcbec9fc7
commit 0e250aee4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 184 additions and 6 deletions

View File

@ -482,3 +482,171 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>;
}
impl<S: Into<Shape>> ShapeWithOneHole for S {
fn into_shape(self, _el_count: usize) -> Result<Shape> {
Ok(self.into())
}
}
impl ShapeWithOneHole for ((),) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
Ok(el_count.into())
}
}
impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
}
}
impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
}
}
impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
}
}
impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
}
}
impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
}
}

View File

@ -1685,12 +1685,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
/// The shape can be specified using a tuple of `usize` and at most one `()` in which case
/// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
/// as to match the number of elements in the tensor.
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@ -1700,10 +1703,14 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
///
/// let c = a.reshape((2, (), 1))?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),

View File

@ -57,10 +57,13 @@ fn main() -> Result<()> {
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
Ok(build_dir) =>
{
Ok(build_dir) => {
let path = PathBuf::from(build_dir);
path.canonicalize().expect(&format!("Directory doesn't exists: {} (the current directory is {})", &path.display(), std::env::current_dir()?.display()))
path.canonicalize().expect(&format!(
"Directory doesn't exists: {} (the current directory is {})",
&path.display(),
std::env::current_dir()?.display()
))
}
};
set_cuda_include_dir()?;