Add a couple cast metal kernels. (#2479)
This commit is contained in:
parent
ebf722b446
commit
6eea45a761
|
@ -412,19 +412,42 @@ impl BackendStorage for MetalStorage {
|
|||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F16) => "cast_u32_f16_strided",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16_strided",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
|
||||
(DType::BF16, DType::I64) => "cast_bf16_i64_strided",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32_strided",
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8_strided",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(DType::F16, DType::I64) => "cast_f16_i64_strided",
|
||||
(DType::F16, DType::U32) => "cast_f16_u32_strided",
|
||||
(DType::F16, DType::U8) => "cast_f16_u8_strided",
|
||||
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F32, DType::I64) => "cast_f32_i64_strided",
|
||||
(DType::F32, DType::U32) => "cast_f32_u32_strided",
|
||||
(DType::F32, DType::U8) => "cast_f32_u8_strided",
|
||||
|
||||
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
||||
(DType::I64, DType::BF16) => "cast_i64_bf16_strided",
|
||||
(DType::I64, DType::F16) => "cast_i64_f16_strided",
|
||||
(DType::I64, DType::U32) => "cast_i64_u32_strided",
|
||||
(DType::I64, DType::U8) => "cast_i64_u8_strided",
|
||||
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16_strided",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16_strided",
|
||||
(DType::U32, DType::F32) => "cast_u32_f32_strided",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8_strided",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64_strided",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32_strided",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8_strided",
|
||||
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16_strided",
|
||||
(DType::U8, DType::F16) => "cast_u8_f16_strided",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32_strided",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32_strided",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue