Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types * update bench calculation * enable testing all conv operations on metal
This commit is contained in:
parent
ec97c98e81
commit
9563a5fee4
|
@ -5,5 +5,6 @@ criterion_main!(
|
|||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::where_cond::benches
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
);
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(
|
||||
x: &Tensor,
|
||||
k: &Tensor,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
) {
|
||||
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let t = Tensor::arange(0.0f32, 10000.0, device)
|
||||
.unwrap()
|
||||
.reshape((1, 4, 50, 50))
|
||||
.unwrap()
|
||||
.to_dtype(dtype)
|
||||
.unwrap();
|
||||
|
||||
let kernel = Tensor::arange(0.0f32, 100.0, device)
|
||||
.unwrap()
|
||||
.reshape((4, 1, 5, 5))
|
||||
.unwrap()
|
||||
.to_dtype(dtype)
|
||||
.unwrap();
|
||||
|
||||
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
|
||||
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
|
||||
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -1,4 +1,5 @@
|
|||
pub(crate) mod affine;
|
||||
pub(crate) mod conv_transpose2d;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod where_cond;
|
||||
|
|
|
@ -2,8 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
|
|||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
|
@ -1074,12 +1074,66 @@ impl BackendStorage for MetalStorage {
|
|||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConvTranspose2D,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("Metal conv_tranpose2d not implemented")
|
||||
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
let (out_w, out_h) = (params.out_w(), params.out_h());
|
||||
let dst_el = params.c_out * out_w * out_h * params.b_size;
|
||||
|
||||
let dims = l.dims();
|
||||
if dims.len() != 4 {
|
||||
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4")
|
||||
}
|
||||
|
||||
let k_dims = kernel_l.dims();
|
||||
if k_dims.len() != 4 {
|
||||
crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4")
|
||||
}
|
||||
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose2d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose2d_f32",
|
||||
DType::F16 => "conv_transpose2d_f16",
|
||||
DType::BF16 => "conv_transpose2d_bf16",
|
||||
dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"),
|
||||
};
|
||||
|
||||
candle_metal_kernels::call_conv_transpose2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
CallConvTranspose2dCfg {
|
||||
dilation: params.dilation,
|
||||
stride: params.stride,
|
||||
padding: params.padding,
|
||||
output_padding: params.output_padding,
|
||||
c_out: params.c_out,
|
||||
out_h: out_h,
|
||||
out_w: out_w,
|
||||
b_size: params.b_size,
|
||||
input_dims: l.dims(),
|
||||
input_stride: l.stride(),
|
||||
kernel_dims: kernel_l.dims(),
|
||||
kernel_stride: kernel_l.stride(),
|
||||
input_offset: l.start_offset() * self.dtype.size_in_bytes(),
|
||||
kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(),
|
||||
},
|
||||
&self.buffer,
|
||||
&kernel.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
|
|
|
@ -163,33 +163,34 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
if !dev.is_metal() {
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
[
|
||||
[
|
||||
[
|
||||
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
|
||||
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
|
||||
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
|
||||
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
|
||||
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
|
||||
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
|
||||
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
|
||||
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
|
||||
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
|
||||
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
|
||||
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
|
||||
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
|
||||
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
|
||||
]
|
||||
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
|
||||
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
|
||||
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
|
||||
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
|
||||
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
|
||||
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
|
||||
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
|
||||
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
|
||||
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
|
||||
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
|
||||
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
|
||||
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
|
||||
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
|
||||
]
|
||||
);
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
// Dilations.
|
||||
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
||||
|
@ -198,44 +199,37 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||
[2.45, -2.3504],
|
||||
);
|
||||
|
||||
if !dev.is_metal() {
|
||||
// Transpose and dilations.
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
// Transpose and dilations.
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
[
|
||||
[
|
||||
[
|
||||
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
|
||||
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
|
||||
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
|
||||
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
|
||||
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
|
||||
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
|
||||
[
|
||||
-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
|
||||
-3.5024
|
||||
],
|
||||
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
|
||||
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
|
||||
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
|
||||
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
|
||||
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
|
||||
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
|
||||
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
|
||||
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
|
||||
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
|
||||
[
|
||||
-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
|
||||
1.0171
|
||||
]
|
||||
]
|
||||
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
|
||||
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
|
||||
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
|
||||
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
|
||||
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
|
||||
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
|
||||
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
|
||||
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
|
||||
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
|
||||
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
|
||||
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
|
||||
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
|
||||
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
|
||||
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
|
||||
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
|
||||
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
|
||||
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
|
||||
]
|
||||
);
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -290,11 +284,6 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
|||
]
|
||||
);
|
||||
|
||||
// conv-transposes are not implemented for metal
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
|
@ -397,9 +386,6 @@ print(w.grad[0])
|
|||
*/
|
||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
// conv-transposes are not implemented for metal
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
use candle_core::Var;
|
||||
let t = Var::from_slice(
|
||||
&[
|
||||
|
|
|
@ -405,6 +405,86 @@ kernel void FN_NAME( \
|
|||
conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
|
||||
} \
|
||||
|
||||
template <typename T, typename A>
|
||||
METAL_FUNC void conv_transpose2d(
|
||||
constant size_t &w_out,
|
||||
constant size_t &h_out,
|
||||
constant size_t &stride,
|
||||
constant size_t &padding,
|
||||
constant size_t &out_padding,
|
||||
constant size_t &dilation,
|
||||
constant size_t *input_dims,
|
||||
constant size_t *input_stride,
|
||||
constant size_t *k_dims,
|
||||
constant size_t *k_stride,
|
||||
device const T *src,
|
||||
device const T *k,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
const size_t h_k = k_dims[2];
|
||||
const size_t w_k = k_dims[3];
|
||||
const size_t c_out = k_dims[1];
|
||||
const size_t c_in = input_dims[1];
|
||||
const size_t h_in = input_dims[2];
|
||||
const size_t w_in = input_dims[3];
|
||||
|
||||
if (tid >= input_dims[0] * c_out * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t b_idx = tid / (w_out * h_out * c_out);
|
||||
const size_t dst_c_idx = (tid / (w_out * h_out)) % c_out;
|
||||
const size_t out_y = (tid / w_out) % h_out;
|
||||
const size_t out_x = tid % w_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * input_stride[0];
|
||||
|
||||
A d = 0;
|
||||
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
|
||||
const int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
|
||||
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
const int inp_x = inp_x_stride / stride;
|
||||
if (inp_x >= w_in) continue;
|
||||
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
|
||||
const int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
|
||||
if (inp_y_stride < 0 || inp_y_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
const int inp_y = inp_y_stride / stride;
|
||||
if (inp_y >= h_in) continue;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * input_stride[1] + inp_y * input_stride[2] + inp_x * input_stride[3];
|
||||
const size_t k_idx = src_c_idx * k_stride[0] + dst_c_idx * k_stride[1] + k_y * k_stride[2] + k_x * k_stride[3];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[tid] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &w_out, \
|
||||
constant size_t &h_out, \
|
||||
constant size_t &stride, \
|
||||
constant size_t &padding, \
|
||||
constant size_t &out_padding, \
|
||||
constant size_t &dilation, \
|
||||
constant size_t *input_dims, \
|
||||
constant size_t *input_stride, \
|
||||
constant size_t *k_dims, \
|
||||
constant size_t *k_stride, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *k, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, input_dims, input_stride, k_dims, k_stride, src, k, dst, tid); \
|
||||
} \
|
||||
|
||||
IM2COL_OP(float, im2col_f32)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
IM2COL_OP(uint32_t, im2col_u32)
|
||||
|
@ -439,4 +519,10 @@ CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
|
|||
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
|
||||
#endif
|
||||
|
||||
CONVT2D_OP(float, float, conv_transpose2d_f32)
|
||||
CONVT2D_OP(half, float, conv_transpose2d_f16)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
|
||||
#endif
|
|
@ -1970,5 +1970,63 @@ pub fn call_conv_transpose1d(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub struct CallConvTranspose2dCfg<'a> {
|
||||
pub dilation: usize,
|
||||
pub stride: usize,
|
||||
pub padding: usize,
|
||||
pub output_padding: usize,
|
||||
pub c_out: usize,
|
||||
pub out_w: usize,
|
||||
pub out_h: usize,
|
||||
pub b_size: usize,
|
||||
pub input_dims: &'a [usize],
|
||||
pub input_stride: &'a [usize],
|
||||
pub kernel_dims: &'a [usize],
|
||||
pub kernel_stride: &'a [usize],
|
||||
pub input_offset: usize,
|
||||
pub kernel_offset: usize,
|
||||
}
|
||||
|
||||
pub fn call_conv_transpose2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
cfg: CallConvTranspose2dCfg,
|
||||
input: &Buffer,
|
||||
kernel: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
cfg.out_w,
|
||||
cfg.out_h,
|
||||
cfg.stride,
|
||||
cfg.padding,
|
||||
cfg.output_padding,
|
||||
cfg.dilation,
|
||||
cfg.input_dims,
|
||||
cfg.input_stride,
|
||||
cfg.kernel_dims,
|
||||
cfg.kernel_stride,
|
||||
(input, cfg.input_offset),
|
||||
(kernel, cfg.kernel_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
|
Loading…
Reference in New Issue