add max_pool2d (#371)
Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
parent 1892bd139c
commit a5c5a893aa
@@ -46,6 +46,7 @@ pub trait BackendStorage: Sized {
     ) -> Result<Self>;
 
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
 
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
@@ -88,6 +88,7 @@ impl Tensor {
            Op::Reshape(node)
            | Op::UpsampleNearest2D(node)
            | Op::AvgPool2D { arg: node, .. }
+           | Op::MaxPool2D { arg: node, .. }
            | Op::Copy(node)
            | Op::Broadcast(node)
            | Op::Cmp(node, _)
@@ -172,6 +173,7 @@ impl Tensor {
            Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
            Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
            Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+           Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
            Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                op: "upsample-nearest2d",
            })?,
@@ -674,6 +674,48 @@ impl Map1 for AvgPool2D {
     }
 }
 
+struct MaxPool2D((usize, usize), (usize, usize));
+
+impl Map1 for MaxPool2D {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
+        let (k_h, k_w) = self.0;
+        let (s_h, s_w) = self.1;
+        let (b_sz, c, h, w) = layout.shape().dims4()?;
+        let stride = layout.stride();
+        let (stride_h, stride_w) = (stride[2], stride[3]);
+        let h_out = (h - k_h) / s_h + 1;
+        let w_out = (w - k_w) / s_w + 1;
+        let src_index = layout.start_offset();
+        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * h_out * w_out..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * h_out * w_out..];
+                let src_index = src_index + c_idx * stride[1];
+                for h_idx in 0..h_out {
+                    for w_idx in 0..w_out {
+                        let mut largest =
+                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
+                        for m in 0..k_h {
+                            for n in 0..k_w {
+                                let m = s_h * h_idx + m;
+                                let n = s_w * w_idx + n;
+                                if largest < src[src_index + m * stride_h + n * stride_w] {
+                                    largest = src[src_index + m * stride_h + n * stride_w]
+                                }
+                            }
+                        }
+                        dst[h_idx * w_out + w_idx] = largest;
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct UpsampleNearest2D(usize, usize);
 
 impl Map1 for UpsampleNearest2D {
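The kernel above assumes no padding and no dilation, so the output size is (h - k_h) / s_h + 1 by (w - k_w) / s_w + 1. As a sanity check, the same window-max arithmetic can be written without candle's Layout/WithDType machinery; the helper below is an illustrative sketch (its name and the contiguous-slice setup are not part of this commit):

    /// Max-pool a contiguous `h` x `w` slice of f32 with kernel `(k_h, k_w)` and stride `(s_h, s_w)`.
    /// Mirrors the loop structure of `MaxPool2D::f` above, minus the batch/channel dimensions.
    fn max_pool2d_plain(src: &[f32], h: usize, w: usize, k: (usize, usize), s: (usize, usize)) -> Vec<f32> {
        let (k_h, k_w) = k;
        let (s_h, s_w) = s;
        // Same output-size formula as the kernel in the diff.
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let mut dst = vec![f32::NEG_INFINITY; h_out * w_out];
        for h_idx in 0..h_out {
            for w_idx in 0..w_out {
                // Scan the (k_h, k_w) window anchored at row s_h * h_idx, column s_w * w_idx.
                for m in 0..k_h {
                    for n in 0..k_w {
                        let v = src[(s_h * h_idx + m) * w + (s_w * w_idx + n)];
                        if v > dst[h_idx * w_out + w_idx] {
                            dst[h_idx * w_out + w_idx] = v;
                        }
                    }
                }
            }
        }
        dst
    }

On the 4x4 input used by the test further down, this produces the same [2, 3, 5, 1] values the committed kernel writes.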
@@ -1664,6 +1706,15 @@ impl BackendStorage for CpuStorage {
         AvgPool2D(kernel_size, stride).map(self, layout)
     }
 
+    fn max_pool2d(
+        &self,
+        layout: &Layout,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
+    ) -> Result<Self> {
+        MaxPool2D(kernel_size, stride).map(self, layout)
+    }
+
     fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         UpsampleNearest2D(h, w).map(self, layout)
     }
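MaxPool2D(kernel_size, stride).map(self, layout) reuses the Map1 plumbing that AvgPool2D already goes through: a single dtype-generic f is applied to whichever element type the CPU storage holds. A minimal sketch of that dispatch pattern follows (hypothetical names and a reduced variant list, not code from the crate):

    // Illustrative only: the real Map1 trait in the CPU backend also threads a Layout
    // through `f` and covers every supported dtype, not just f32/f64.
    enum CpuStorageSketch {
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    trait Map1Sketch {
        fn f<T: Copy + PartialOrd>(&self, src: &[T]) -> Vec<T>;

        // One generic implementation of `f` serves all dtype variants.
        fn map(&self, storage: &CpuStorageSketch) -> CpuStorageSketch {
            match storage {
                CpuStorageSketch::F32(v) => CpuStorageSketch::F32(self.f(v)),
                CpuStorageSketch::F64(v) => CpuStorageSketch::F64(self.f(v)),
            }
        }
    }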
@@ -1395,6 +1395,10 @@ impl BackendStorage for CudaStorage {
         todo!()
     }
 
+    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        todo!()
+    }
+
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
         todo!()
     }
@@ -134,6 +134,10 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -93,6 +93,13 @@ pub enum Op {
         kernel_size: (usize, usize),
         stride: (usize, usize),
     },
+
+    MaxPool2D {
+        arg: Tensor,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
+    },
 
     UpsampleNearest2D(Tensor),
 
     Cat(Vec<Tensor>, usize),
@@ -311,6 +311,24 @@ impl Storage {
         }
     }
 
+    pub(crate) fn max_pool2d(
+        &self,
+        layout: &Layout,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
+    ) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
@@ -872,6 +872,22 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
     }
 
+    pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+        let (n, c, h, w) = self.dims4()?;
+        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
+        let h_out = (h - kernel_size.0) / stride.0 + 1;
+        let w_out = (w - kernel_size.1) / stride.1 + 1;
+        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
+            arg,
+            kernel_size,
+            stride,
+        });
+        let storage = self
+            .storage()
+            .max_pool2d(self.layout(), kernel_size, stride)?;
+        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
+    }
+
     /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
     ///
     /// # Arguments
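Assuming the crate is consumed the same way the tests below consume it, user code would call the new method roughly like this (a sketch: the candle_core path, the anyhow wrapper, and the main function are assumptions, not part of the commit):

    use candle_core::{Device, Tensor};

    fn main() -> anyhow::Result<()> {
        // Single-image, single-channel 4x4 input in NCHW layout, same values as the test below.
        let data: Vec<f32> = vec![
            1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
        ];
        let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
        // 2x2 windows with stride 2 -> a (1, 1, 2, 2) tensor holding each window's maximum.
        let pooled = t.max_pool2d((2, 2), (2, 2))?;
        assert_eq!(pooled.dims(), &[1, 1, 2, 2]);
        Ok(())
    }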
@@ -14,6 +14,18 @@ fn avg_pool2d() -> anyhow::Result<()> {
     Ok(())
 }
 
+#[test]
+fn max_pool2d() -> anyhow::Result<()> {
+    let data: Vec<f32> = vec![
+        1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
+    ];
+    let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+
+    let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+    assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
+    Ok(())
+}
+
 /* This test corresponds to the following PyTorch script.
 import torch
 torch.manual_seed(4242)
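For reference, laid out as a 4x4 grid the test input and its 2x2 / stride-2 window maxima are:

    1 2 | 1 3      top-left max:    2    top-right max:    3
    0 0 | 1 1
    ----+----
    1 1 | 1 1      bottom-left max: 5    bottom-right max: 1
    5 1 | 1 1

which matches the asserted [[2., 3.], [5., 1.]].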