Support i64 in index-select on Metal. (#1951)

* Support i64 in index-select on Metal.

* Add some testing of index-select for all index dtypes (U8, U32, I64).
This commit is contained in:
Laurent Mazare 2024-03-27 16:30:07 +01:00 committed by GitHub
parent a9abde5f93
commit ab86cd37c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 38 deletions

View File

@ -1391,6 +1391,10 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16", (DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
} }

View File

@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?; let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(()) Ok(())
} }
@ -734,6 +736,8 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0] [9.0, 10.0, 11.0]
] ]
); );
for dtype in [DType::U8, DType::U32, DType::I64] {
let ids = ids.to_dtype(dtype)?;
let hs = t.index_select(&ids, 1)?; let hs = t.index_select(&ids, 1)?;
assert_eq!( assert_eq!(
hs.to_vec2::<f32>()?, hs.to_vec2::<f32>()?,
@ -772,6 +776,7 @@ fn index_select(device: &Device) -> Result<()> {
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?; let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
}
Ok(()) Ok(())
} }

View File

@ -187,6 +187,12 @@ kernel void NAME( \
} }
INDEX_OP(is_i64_f32, int64_t, float)
INDEX_OP(is_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif
INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half) INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__) #if defined(__HAVE_BFLOAT__)