Upsample grad (#1420)
* encode size of upsample in enum
* working convolution method for limited 2d kernels
* add test for sf 3 interpolation
* add higher dimensional tests, fix to work with multichannel input
* Remove commented out line.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
parent 9bd94c1ffa
commit 18eb87f25f
@@ -114,7 +114,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
                 | Op::UpsampleNearest1D(node)
-                | Op::UpsampleNearest2D(node)
+                | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
                 Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest1d",
                 })?,
-                Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
-                    op: "upsample-nearest2d",
-                })?,
+                Op::UpsampleNearest2D {
+                    arg,
+                    target_h,
+                    target_w,
+                } => {
+                    let (_n, c, h, w) = arg.dims4()?;
+                    if target_h % h != 0 || target_w % w != 0 {
+                        crate::bail!("backward not supported for non integer upscaling factors")
+                    }
+                    let scale_h = target_h / h;
+                    let scale_w = target_w / w;
+
+                    if scale_h != scale_w {
+                        crate::bail!("backward not supported for non uniform upscaling factors")
+                    };
+                    let kernel =
+                        Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+                    let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = conv_sum;
+                }
                 Op::SliceScatter0(lhs, rhs, start_rhs) => {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
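The new backward arm relies on the adjoint of nearest-neighbor upsampling being a block sum: each input pixel is copied into a scale_h x scale_w patch of the output, so its gradient is the sum of the incoming gradient over that patch. A conv2d with an all-ones (c, 1, scale_h, scale_w) kernel, stride equal to the scale, and groups = c computes exactly that, channel by channel, which is the `grad.conv2d(&kernel, 0, scale_h, 1, c)` call above. A minimal sketch of the equivalence, assuming candle-core's public API on the CPU device:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Incoming gradient for a 2x upsample of a (1, 1, 2, 2) input.
    let grad = Tensor::arange(1f32, 17., &dev)?.reshape((1, 1, 4, 4))?;
    // All-ones kernel, one (1, 2, 2) filter per channel (here c = 1).
    let kernel = Tensor::ones((1, 1, 2, 2), grad.dtype(), &dev)?;
    // padding 0, stride 2, dilation 1, groups 1: sums each 2x2 block.
    let sums = grad.conv2d(&kernel, 0, 2, 1, 1)?;
    // Block sums: 1+2+5+6, 3+4+7+8, 9+10+13+14, 11+12+15+16.
    assert_eq!(sums.flatten_all()?.to_vec1::<f32>()?, [14., 22., 46., 54.]);
    Ok(())
}

Because groups = c, each channel is convolved only with its own ones filter, so there is no cross-channel mixing; this is what the multichannel tests further down exercise.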
@@ -132,7 +132,11 @@ pub enum Op {
     },
 
     UpsampleNearest1D(Tensor),
-    UpsampleNearest2D(Tensor),
+    UpsampleNearest2D {
+        arg: Tensor,
+        target_h: usize,
+        target_w: usize,
+    },
 
     Cat(Vec<Tensor>, usize),
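This variant change is what makes the backward pass possible: the old tuple variant recorded only the input tensor, so the target size, and hence the scale factor, was unavailable at gradient time. A standalone sketch of the idea with hypothetical stand-in types (not candle's actual definitions):

// Stand-in types to illustrate the refactor; not candle's actual definitions.
struct T;

enum Op {
    // Before: UpsampleNearest2D(T) -- the target size was lost after the
    // forward pass. After: the op records the output size it produced.
    UpsampleNearest2D { arg: T, target_h: usize, target_w: usize },
}

// Backward can now recover the integer scale factors from the stored size.
fn scale_factors(op: &Op, in_h: usize, in_w: usize) -> (usize, usize) {
    match op {
        Op::UpsampleNearest2D { target_h, target_w, .. } => {
            (target_h / in_h, target_w / in_w)
        }
    }
}

fn main() {
    let op = Op::UpsampleNearest2D { arg: T, target_h: 6, target_w: 6 };
    assert_eq!(scale_factors(&op, 3, 3), (2, 2));
}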
@@ -994,7 +994,11 @@ impl Tensor {
     /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
     pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
-        let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+            arg,
+            target_h,
+            target_w,
+        });
         let storage = self
             .storage()
             .upsample_nearest2d(self.layout(), target_h, target_w)?;
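The forward behavior of interpolate2d is unchanged here; only the op recorded for backprop now carries the target size. For reference, a minimal usage sketch, assuming candle-core on CPU:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, channels, h, w) = (1, 1, 2, 2)
    let t = Tensor::new(&[[[[1f32, 2.], [3., 4.]]]], &dev)?;
    // Nearest-neighbor upsample to (1, 1, 4, 4): each value fills a 2x2 block.
    let up = t.interpolate2d(4, 4)?;
    assert_eq!(
        up.flatten_all()?.to_vec1::<f32>()?,
        [1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4.]
    );
    Ok(())
}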
@@ -270,6 +270,166 @@ fn unary_grad(device: &Device) -> Result<()> {
         [0.7358, 2.0000, 0.2707, 1.0000]
     );
 
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+7+8 = 18
+    // 3+4+9+10 = 26
+    // 5+6+11+12 = 34
+    // row 2
+    // 13+14+19+20 = 66
+    // 15+16+21+22 = 74
+    // 17+18+23+24 = 82
+    // row 3
+    // 25+26+31+32 = 114
+    // 27+28+33+34 = 122
+    // 29+30+35+36 = 130
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+3+7+8+9+13+14+15 = 72
+    // 4+5+6+10+11+12+16+17+18 = 99
+    // row 2
+    // 19+20+21+25+26+27+31+32+33 = 234
+    // 22+23+24+28+29+30+34+35+36 = 261
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[72_f32, 99.], [234., 261.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(
+        &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
+        device,
+    )?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
     Ok(())
 }
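The expected gradients in these tests can be cross-checked independently of candle: for a scale-s nearest upsample, the gradient of x[i][j] is the sum of z over the s x s block that x[i][j] was copied into. A plain-Rust sketch (block_sums is a hypothetical helper, not part of the test suite):

// Gradient of a scale-s nearest-neighbor upsample w.r.t. an (h, w) input:
// sum the upstream gradient z (flattened row-major, length (h*s)*(w*s))
// over each s x s output block.
fn block_sums(z: &[f32], h: usize, w: usize, s: usize) -> Vec<f32> {
    let ow = w * s;
    let mut out = vec![0f32; h * w];
    for (idx, v) in z.iter().enumerate() {
        let (i, j) = (idx / ow, idx % ow);
        out[(i / s) * w + j / s] += v;
    }
    out
}

fn main() {
    // First test above: 3x3 input upsampled 2x, z = 1..=36.
    let z: Vec<f32> = (1..=36).map(|v| v as f32).collect();
    assert_eq!(
        block_sums(&z, 3, 3, 2),
        [18., 26., 34., 66., 74., 82., 114., 122., 130.]
    );
    // Second test: 2x2 input upsampled 3x.
    assert_eq!(block_sums(&z, 2, 2, 3), [72., 99., 234., 261.]);
}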