Merge remote-tracking branch 'origin/main' into cuda-conv-tr1d
This commit is contained in:
commit
53f951f6e2
|
@ -609,28 +609,41 @@ impl BackendStorage for MetalStorage {
|
|||
let command_buffer = device.command_buffer()?;
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
(DType::U8, DType::F16) => "cast_u8_f16",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F32, DType::I64) => "cast_f32_i64",
|
||||
(DType::F32, DType::U32) => "cast_f32_u32",
|
||||
(DType::F32, DType::U8) => "cast_f32_u8",
|
||||
|
||||
(DType::I64, DType::BF16) => "cast_i64_bf16",
|
||||
(DType::I64, DType::F16) => "cast_i64_f16",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(DType::I64, DType::U32) => "cast_i64_u32",
|
||||
(DType::I64, DType::U8) => "cast_i64_u8",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(DType::F16, DType::I64) => "cast_f16_i64",
|
||||
(DType::F16, DType::U32) => "cast_f16_u32",
|
||||
(DType::F16, DType::U8) => "cast_f16_u8",
|
||||
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||
(DType::BF16, DType::I64) => "cast_bf16_i64",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
|
|
|
@ -109,8 +109,7 @@ fn main() -> Result<()> {
|
|||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
|
||||
codes
|
||||
codes.get("codes").expect("no codes in input file").clone()
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
|
||||
|
|
|
@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \
|
|||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
|
||||
} \
|
||||
|
||||
// u32
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
|
||||
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
|
||||
#if __METAL_VERSION__ >= 220
|
||||
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
|
||||
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
#endif
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
#endif
|
||||
|
||||
// u8
|
||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
|
||||
#if __METAL_VERSION__ >= 220
|
||||
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
|
||||
#endif
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
#endif
|
||||
|
||||
// f16
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
|
||||
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
|
||||
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
||||
|
||||
// i64
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
|
||||
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
|
||||
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
|
||||
#endif
|
||||
|
||||
// f32
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
|
||||
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
|
||||
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
#endif
|
||||
|
||||
// bf16
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
|
||||
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
|
@ -292,7 +292,7 @@ fn binary_ops_bf16() {
|
|||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
|
@ -319,107 +319,189 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn cast_u32_f32() {
|
||||
let v = vec![1u32, 2, 3];
|
||||
let results = cast(&v, "cast_u32_f32");
|
||||
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
|
||||
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
fn cast_f32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
|
||||
// f32 -> f16
|
||||
let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results.len(), 10_000);
|
||||
assert_eq!(&results[..10], vec![1.0f32; 10]);
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
// f32 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// f32 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// f32 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// f32 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
fn cast_f16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
// f16 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
|
||||
assert_eq!(output, expected);
|
||||
// f16 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// f16 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// f16 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// f16 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
fn cast_bf16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
// bf16 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
|
||||
assert_eq!(output, expected);
|
||||
// bf16 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// bf16 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// bf16 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// bf16 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
fn cast_u32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
// u32 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
|
||||
assert_eq!(output, expected);
|
||||
// u32 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// u32 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// u32 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// u32 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u32_bf16() {
|
||||
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
fn cast_u8() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
// u8 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
|
||||
assert_eq!(output, expected);
|
||||
// u8 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// u8 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// u8 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// u8 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f32_bf16() {
|
||||
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
fn cast_i64() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
// i64 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
// i64 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u8() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
// i64 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
|
||||
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
|
||||
// i64 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f16() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
|
||||
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f16_bf16() {
|
||||
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
// i64 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
|
|
Loading…
Reference in New Issue