Merge remote-tracking branch 'origin/main' into cuda-conv-tr1d

This commit is contained in:
laurent 2024-03-17 21:17:56 +01:00
commit 53f951f6e2
4 changed files with 221 additions and 94 deletions

View File

@ -609,28 +609,41 @@ impl BackendStorage for MetalStorage {
let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::U8, DType::F16) => "cast_u8_f16",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::I64) => "cast_f32_i64",
(DType::F32, DType::U32) => "cast_f32_u32",
(DType::F32, DType::U8) => "cast_f32_u8",
(DType::I64, DType::BF16) => "cast_i64_bf16",
(DType::I64, DType::F16) => "cast_i64_f16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::I64, DType::U32) => "cast_i64_u32",
(DType::I64, DType::U8) => "cast_i64_u8",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::F16, DType::I64) => "cast_f16_i64",
(DType::F16, DType::U32) => "cast_f16_u32",
(DType::F16, DType::U8) => "cast_f16_u8",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
(DType::BF16, DType::I64) => "cast_bf16_i64",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")

View File

@ -109,8 +109,7 @@ fn main() -> Result<()> {
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
codes
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let (pcm, sample_rate) = pcm_decode(args.in_file)?;

View File

@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
// u32
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
#endif
// u8
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
#endif
// f16
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
// i64
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif
// f32
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif
// bf16
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif

View File

@ -292,7 +292,7 @@ fn binary_ops_bf16() {
binary_op!(max, |x: bf16, y| x.max(y));
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
@ -319,107 +319,189 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
}
#[test]
fn cast_u32_f32() {
let v = vec![1u32, 2, 3];
let results = cast(&v, "cast_u32_f32");
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
fn cast_f32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let v = vec![1.0f32, 2.0, 3.0];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
// f32 -> f16
let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
assert_eq!(results, v_f16);
let v = vec![1.0f32; 10_000];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results.len(), 10_000);
assert_eq!(&results[..10], vec![1.0f32; 10]);
assert_eq!(results, vec![1.0f32; 10_000]);
// f32 -> bf16
let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
assert_eq!(results, v_bf16);
// f32 -> u32
let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
assert_eq!(results, v_u32);
// f32 -> u8
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
assert_eq!(results, v_u8);
// f32 -> i64
let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_bf16_u32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
fn cast_f16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
// f16 -> f32
let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// f16 -> bf16
let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
assert_eq!(results, v_bf16);
// f16 -> u32
let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
assert_eq!(results, v_u32);
// f16 -> u8
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
assert_eq!(results, v_u8);
// f16 -> i64
let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_bf16_f32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
fn cast_bf16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
// bf16 -> f32
let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// bf16 -> f16
let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
assert_eq!(results, v_f16);
// bf16 -> u32
let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
assert_eq!(results, v_u32);
// bf16 -> u8
let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
assert_eq!(results, v_u8);
// bf16 -> i64
let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_u8_bf16() {
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
fn cast_u32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
let expected: Vec<bf16> = input
.iter()
.map(|v| bf16::from_f32(*v as f32))
.collect::<Vec<_>>();
// u32 -> f32
let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// u32 -> f16
let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
assert_eq!(results, v_f16);
// u32 -> bf16
let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
assert_eq!(results, v_bf16);
// u32 -> u8
let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
assert_eq!(results, v_u8);
// u32 -> i64
let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_u32_bf16() {
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
fn cast_u8() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
// u8 -> f32
let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
// u8 -> f16
let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
assert_eq!(results, v_f16);
// u8 -> bf16
let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
assert_eq!(results, v_bf16);
// u8 -> u32
let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
assert_eq!(results, v_u32);
// u8 -> i64
let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
assert_eq!(results, v_i64);
}
#[test]
fn it_cast_f32_bf16() {
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
fn cast_i64() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
// i64 -> f32
let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
assert_eq!(results, v_f32);
assert_eq!(output, expected);
}
// i64 -> f16
let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
assert_eq!(results, v_f16);
#[test]
fn it_cast_bf16_u8() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
// i64 -> bf16
let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
assert_eq!(results, v_bf16);
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
// i64 -> u32
let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
assert_eq!(results, v_u32);
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_f16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f16_bf16() {
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
// i64 -> u8
let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
assert_eq!(results, v_u8);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {