feat(bf16): add cast support + tests for cast + bin ops (#1524)
This commit is contained in:
parent
9f0c99f0c1
commit
402349d120
|
@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
|
|||
[[bench]]
|
||||
name = "bench_main"
|
||||
harness = false
|
||||
|
||||
|
|
|
@ -590,14 +590,26 @@ impl BackendStorage for MetalStorage {
|
|||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
|
@ -1131,8 +1143,12 @@ impl BackendStorage for MetalStorage {
|
|||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(DType::U32, DType::BF16) => "is_u32_bf16",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
||||
}
|
||||
|
@ -1322,6 +1338,7 @@ impl MetalStorage {
|
|||
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||
|
||||
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||
|
@ -1332,6 +1349,18 @@ impl MetalStorage {
|
|||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
|
||||
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
|
||||
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
|
||||
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
|
||||
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
|
||||
|
||||
("add", DType::I64) => (contiguous::add::I64, self.dtype),
|
||||
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
|
||||
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
|
||||
|
@ -1342,6 +1371,7 @@ impl MetalStorage {
|
|||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||
|
||||
("add", DType::U32) => (contiguous::add::U32, self.dtype),
|
||||
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
|
||||
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
|
||||
|
@ -1352,6 +1382,7 @@ impl MetalStorage {
|
|||
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
|
||||
|
||||
("add", DType::U8) => (contiguous::add::U8, self.dtype),
|
||||
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
|
||||
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
|
||||
|
@ -1362,6 +1393,7 @@ impl MetalStorage {
|
|||
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
@ -1395,6 +1427,7 @@ impl MetalStorage {
|
|||
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||
|
||||
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||
|
@ -1407,6 +1440,20 @@ impl MetalStorage {
|
|||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
|
||||
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
|
||||
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
|
||||
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
|
||||
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
|
||||
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
|
||||
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
|
||||
|
||||
("badd", DType::I64) => (strided::add::I64, self.dtype),
|
||||
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
|
||||
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
|
||||
|
@ -1419,6 +1466,7 @@ impl MetalStorage {
|
|||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||
|
||||
("badd", DType::U32) => (strided::add::U32, self.dtype),
|
||||
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
|
||||
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
|
||||
|
@ -1431,6 +1479,7 @@ impl MetalStorage {
|
|||
("lt", DType::U32) => (strided::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (strided::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (strided::gt::U32, DType::U8),
|
||||
|
||||
("badd", DType::U8) => (strided::add::U8, self.dtype),
|
||||
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
|
||||
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
|
||||
|
@ -1443,6 +1492,7 @@ impl MetalStorage {
|
|||
("lt", DType::U8) => (strided::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (strided::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (strided::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
|
|
@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
|
|||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
half = { version = "2.3.1", features = [
|
||||
"num-traits",
|
||||
"use-intrinsics",
|
||||
"rand_distr",
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
|
|
|
@ -28,7 +28,7 @@ kernel void FN_NAME( \
|
|||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
|
@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
|
|||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
|
@ -59,6 +86,15 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
|||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
||||
|
|
|
@ -174,6 +174,9 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
|||
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||
|
||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use metal::{Buffer, Device, MTLResourceOptions};
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
|
@ -248,6 +248,34 @@ fn binary_add_f32() {
|
|||
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_ops_bf16() {
|
||||
let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
|
||||
let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect();
|
||||
|
||||
macro_rules! binary_op {
|
||||
($opname:ident, $opexpr:expr) => {{
|
||||
let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
|
||||
let expected: Vec<bf16> = lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
|
||||
.collect();
|
||||
assert_eq!(results, expected);
|
||||
}};
|
||||
}
|
||||
|
||||
binary_op!(add, |x, y| x + y);
|
||||
binary_op!(sub, |x, y| x - y);
|
||||
binary_op!(mul, |x, y| x * y);
|
||||
binary_op!(div, |x, y| x / y);
|
||||
binary_op!(min, |x: bf16, y| x.min(y));
|
||||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
|
@ -296,6 +324,89 @@ fn cast_u32_f32() {
|
|||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u32_bf16() {
|
||||
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f32_bf16() {
|
||||
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u8() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
|
||||
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f16() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
|
||||
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f16_bf16() {
|
||||
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
|
@ -396,14 +507,14 @@ fn index_select() {
|
|||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
||||
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [2, 5];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
|
@ -419,20 +530,46 @@ fn index_select_f16() {
|
|||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
|
||||
assert_eq!(
|
||||
approx_f16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u32_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u8_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u8, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_dim1() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 1;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||
|
@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
|||
shape: &[usize],
|
||||
ids: &[I],
|
||||
dim: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
||||
|
@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
|||
let dst_el = ids.len() * left_size * right_size;
|
||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
let name = match core::mem::size_of::<T>() {
|
||||
4 => "is_u32_f32",
|
||||
2 => "is_u32_f16",
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
call_index_select(
|
||||
|
|
Loading…
Reference in New Issue