Merge pull request #1479 from huggingface/upsample_metal
Adding upsample_nearest_2d.
This commit is contained in:
commit
eae3a20d43
|
@ -959,8 +959,39 @@ impl BackendStorage for MetalStorage {
|
|||
crate::bail!("upsample_nearest1d metal")
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||
crate::bail!("upsample_nearest2d metal")
|
||||
fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
|
||||
// let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let strides = inp_l.stride();
|
||||
if dims.len() != 4 {
|
||||
crate::bail!("unexpected input shape for upsample {dims:?}")
|
||||
}
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "upsample_nearest2d_f32",
|
||||
dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
|
||||
};
|
||||
|
||||
let dst_el = out_w * out_h * dims[0] * dims[1];
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_upsample_nearest_2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
dims,
|
||||
strides,
|
||||
out_w,
|
||||
out_h,
|
||||
&self.buffer,
|
||||
inp_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype))
|
||||
}
|
||||
|
||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
|
|
|
@ -108,6 +108,47 @@ METAL_FUNC void im2col1d(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void upsample_nearest2d(
|
||||
constant size_t &w_out,
|
||||
constant size_t &h_out,
|
||||
constant float &w_scale,
|
||||
constant float &h_scale,
|
||||
constant size_t *src_dims,
|
||||
constant size_t *src_s,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
|
||||
const size_t c = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
if (tid >= src_dims[0] * c * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Improve this.
|
||||
const size_t b_idx = tid / (w_out * h_out * c);
|
||||
const size_t c_idx = (tid / (w_out * h_out)) % c;
|
||||
const size_t dst_w = (tid / h_out) % w_out;
|
||||
const size_t dst_h = tid % h_out;
|
||||
|
||||
size_t src_w = static_cast<size_t>(dst_w * w_scale);
|
||||
size_t src_h = static_cast<size_t>(dst_h * h_scale);
|
||||
if (src_w >= w_in) {
|
||||
src_w = w_in - 1;
|
||||
}
|
||||
if (src_h >= h_in) {
|
||||
src_h = h_in - 1;
|
||||
}
|
||||
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
|
||||
dst[tid] = src[src_i];
|
||||
}
|
||||
|
||||
#define IM2COL_OP(T, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dst_numel, \
|
||||
|
@ -143,6 +184,21 @@ kernel void FN_NAME( \
|
|||
) { \
|
||||
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &w_out, \
|
||||
constant size_t &h_out, \
|
||||
constant float &w_scale, \
|
||||
constant float &h_scale, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const TYPENAME *src, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
IM2COL_OP(float, im2col_f32)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
|
@ -151,3 +207,7 @@ IM2COL_OP(uint32_t, im2col_u32)
|
|||
IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||
|
||||
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
|
||||
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
|
||||
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
|
||||
|
|
|
@ -1518,6 +1518,50 @@ pub fn call_im2col_strided(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_upsample_nearest_2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
out_w: usize,
|
||||
out_h: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let dst_el = out_w * out_h * shape[0] * shape[1];
|
||||
let scale_w = shape[2] as f32 / out_w as f32;
|
||||
let scale_h = shape[3] as f32 / out_h as f32;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
out_w,
|
||||
out_h,
|
||||
scale_w,
|
||||
scale_h,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue