Extend supported dtypes for metal (im2col & upsample_2d) (#1938)

* update im2col dtype implementations

* update dtypes for upsample
This commit is contained in:
Thomas Santerre 2024-03-26 01:48:56 -04:00 committed by GitHub
parent 196765e995
commit f5dfe883d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 0 deletions

View File

@ -1038,6 +1038,10 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col_f32",
DType::F16 => "im2col_f16",
DType::BF16 => "im2col_bf16",
DType::U8 => "im2col_u8",
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_im2col_strided(
@ -1250,6 +1254,10 @@ impl BackendStorage for MetalStorage {
}
let name = match self.dtype {
DType::F32 => "upsample_nearest2d_f32",
DType::F16 => "upsample_nearest2d_f16",
DType::BF16 => "upsample_nearest2d_bf16",
DType::U8 => "upsample_nearest2d_u8",
DType::U32 => "upsample_nearest2d_u32",
dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
};

View File

@ -486,16 +486,24 @@ kernel void FN_NAME( \
} \
IM2COL_OP(float, im2col_f32)
IM2COL_OP(half, im2col_f16)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
#if defined(__HAVE_BFLOAT__)
IM2COL_OP(bfloat, im2col_bf16)
#endif
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
#if defined(__HAVE_BFLOAT__)
UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16)
#endif
MAXPOOL2D_OP(float, max_pool2d_f32)
MAXPOOL2D_OP(half, max_pool2d_f16)