From 42f39f16b34d1792365c3a7b0ded618f7b49da8d Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 4 Nov 2024 19:01:01 +0100 Subject: [PATCH] Add more type support for burn-jit (#2454) --- Cargo.lock | 181 ++++++---- Cargo.toml | 9 +- crates/burn-autodiff/src/runtime/mspc.rs | 4 +- crates/burn-autodiff/src/tests/abs.rs | 8 +- .../src/tests/adaptive_avgpool1d.rs | 2 +- .../src/tests/adaptive_avgpool2d.rs | 2 +- crates/burn-autodiff/src/tests/avgpool1d.rs | 2 +- crates/burn-autodiff/src/tests/avgpool2d.rs | 2 +- crates/burn-autodiff/src/tests/cat.rs | 8 +- crates/burn-autodiff/src/tests/conv1d.rs | 6 +- crates/burn-autodiff/src/tests/conv2d.rs | 6 +- crates/burn-autodiff/src/tests/conv3d.rs | 6 +- .../src/tests/conv_transpose1d.rs | 6 +- .../src/tests/conv_transpose2d.rs | 6 +- .../src/tests/conv_transpose3d.rs | 6 +- .../burn-autodiff/src/tests/deform_conv2d.rs | 2 +- crates/burn-autodiff/src/tests/mod.rs | 49 ++- crates/burn-autodiff/src/tests/sqrt.rs | 4 +- crates/burn-autodiff/src/tests/transpose.rs | 12 +- crates/burn-candle/src/lib.rs | 2 + .../burn-core/src/nn/transformer/decoder.rs | 2 +- .../burn-core/src/nn/transformer/encoder.rs | 2 +- crates/burn-cuda/Cargo.toml | 21 +- crates/burn-cuda/src/lib.rs | 3 +- .../burn-dataset/examples/speech_commands.rs | 2 +- .../burn-dataset/src/audio/speech_commands.rs | 2 +- crates/burn-dataset/src/dataset/dataframe.rs | 4 +- crates/burn-fusion/src/stream/context.rs | 62 +++- crates/burn-jit/Cargo.toml | 7 +- crates/burn-jit/src/element.rs | 18 +- .../src/fusion/elemwise/optimization.rs | 32 +- crates/burn-jit/src/fusion/kernel.rs | 337 ----------------- crates/burn-jit/src/fusion/on_write/io.rs | 276 +++++++++++++- crates/burn-jit/src/fusion/on_write/ir.rs | 40 ++- crates/burn-jit/src/fusion/on_write/kernel.rs | 271 +++++++++++++- crates/burn-jit/src/fusion/on_write/trace.rs | 71 +++- crates/burn-jit/src/kernel/binary.rs | 2 +- crates/burn-jit/src/kernel/comparison.rs | 6 +- .../src/kernel/conv/conv2d/implicit_gemm.rs | 34 +- .../kernel/conv/conv2d/transpose_direct.rs | 135 +++---- .../src/kernel/conv/conv2d/tune/conv2d.rs | 10 +- .../conv/conv2d/tune/conv_transpose2d.rs | 1 + .../src/kernel/conv/conv2d/tune/key.rs | 3 + .../src/kernel/conv/conv_transpose3d.rs | 169 ++++----- .../kernel/conv/deform_conv_transpose2d.rs | 15 +- crates/burn-jit/src/kernel/index/flip.rs | 15 +- .../burn-jit/src/kernel/index/repeat_dim.rs | 13 +- crates/burn-jit/src/kernel/index/scatter.rs | 10 +- crates/burn-jit/src/kernel/index/slice.rs | 19 +- .../burn-jit/src/kernel/index/slice_assign.rs | 27 +- .../src/kernel/interpolate/bicubic.rs | 87 ++--- .../src/kernel/interpolate/bilinear.rs | 65 ++-- .../src/kernel/interpolate/nearest.rs | 47 +-- .../kernel/interpolate/nearest_backward.rs | 57 +-- .../burn-jit/src/kernel/matmul/tune/base.rs | 2 +- crates/burn-jit/src/kernel/matmul/tune/key.rs | 17 +- .../src/kernel/pool/avg_pool2d_backward.rs | 107 +++--- .../src/kernel/pool/max_pool2d_backward.rs | 79 ++-- crates/burn-jit/src/kernel/prng/base.rs | 32 +- crates/burn-jit/src/kernel/prng/bernoulli.rs | 14 +- crates/burn-jit/src/kernel/prng/normal.rs | 9 +- crates/burn-jit/src/kernel/prng/uniform.rs | 5 +- .../src/kernel/reduce/shared/argmax.rs | 3 +- .../src/kernel/reduce/shared/argmin.rs | 3 +- .../src/kernel/reduce/shared/shader.rs | 29 +- .../burn-jit/src/kernel/reduce/tune/base.rs | 1 + crates/burn-jit/src/kernel/reduce/tune/key.rs | 10 +- crates/burn-jit/src/ops/base.rs | 8 + crates/burn-jit/src/tests/mod.rs | 40 ++- 
crates/burn-tensor/Cargo.toml | 21 +- crates/burn-tensor/src/lib.rs | 15 +- crates/burn-tensor/src/tensor/data.rs | 77 +++- crates/burn-tensor/src/tensor/element/base.rs | 24 ++ crates/burn-tensor/src/tensor/element/cast.rs | 44 +++ .../src/tests/activation/hard_sigmoid.rs | 2 +- .../src/tests/activation/leaky_relu.rs | 8 +- .../src/tests/activation/log_sigmoid.rs | 9 +- .../burn-tensor/src/tests/activation/mish.rs | 2 +- .../src/tests/activation/sigmoid.rs | 6 +- .../burn-tensor/src/tests/activation/silu.rs | 2 +- .../src/tests/activation/softmax.rs | 2 +- .../src/tests/activation/softmin.rs | 2 +- .../src/tests/activation/softplus.rs | 10 +- .../src/tests/activation/tanh_activation.rs | 2 +- crates/burn-tensor/src/tests/mod.rs | 339 ++++++++++++------ .../burn-tensor/src/tests/module/avgpool2d.rs | 2 +- .../src/tests/module/bicubic_interpolate.rs | 4 +- .../src/tests/module/bilinear_interpolate.rs | 4 +- .../src/tests/module/deform_conv2d.rs | 30 +- .../burn-tensor/src/tests/module/forward.rs | 4 +- .../src/tests/module/nearest_interpolate.rs | 4 +- crates/burn-tensor/src/tests/ops/arange.rs | 8 +- .../burn-tensor/src/tests/ops/arange_step.rs | 12 +- .../src/tests/ops/cartesian_grid.rs | 5 +- crates/burn-tensor/src/tests/ops/cast.rs | 4 +- crates/burn-tensor/src/tests/ops/cat.rs | 8 +- crates/burn-tensor/src/tests/ops/ceil.rs | 2 +- crates/burn-tensor/src/tests/ops/clamp.rs | 12 +- crates/burn-tensor/src/tests/ops/cos.rs | 2 +- crates/burn-tensor/src/tests/ops/div.rs | 20 +- crates/burn-tensor/src/tests/ops/erf.rs | 8 +- crates/burn-tensor/src/tests/ops/exp.rs | 4 +- crates/burn-tensor/src/tests/ops/expand.rs | 15 +- crates/burn-tensor/src/tests/ops/flatten.rs | 12 +- crates/burn-tensor/src/tests/ops/flip.rs | 10 +- crates/burn-tensor/src/tests/ops/floor.rs | 2 +- crates/burn-tensor/src/tests/ops/full.rs | 8 +- crates/burn-tensor/src/tests/ops/init.rs | 14 +- crates/burn-tensor/src/tests/ops/iter_dim.rs | 18 +- crates/burn-tensor/src/tests/ops/log.rs | 4 +- crates/burn-tensor/src/tests/ops/log1p.rs | 4 +- .../src/tests/ops/map_comparison.rs | 18 +- crates/burn-tensor/src/tests/ops/mask.rs | 42 +-- crates/burn-tensor/src/tests/ops/matmul.rs | 18 +- crates/burn-tensor/src/tests/ops/movedim.rs | 22 +- crates/burn-tensor/src/tests/ops/mul.rs | 26 +- crates/burn-tensor/src/tests/ops/neg.rs | 2 +- crates/burn-tensor/src/tests/ops/padding.rs | 31 +- crates/burn-tensor/src/tests/ops/permute.rs | 10 +- crates/burn-tensor/src/tests/ops/powf.rs | 30 +- .../burn-tensor/src/tests/ops/powf_scalar.rs | 12 +- crates/burn-tensor/src/tests/ops/random.rs | 21 +- crates/burn-tensor/src/tests/ops/recip.rs | 2 +- crates/burn-tensor/src/tests/ops/remainder.rs | 16 +- crates/burn-tensor/src/tests/ops/repeat.rs | 12 +- .../burn-tensor/src/tests/ops/repeat_dim.rs | 12 +- crates/burn-tensor/src/tests/ops/reshape.rs | 18 +- crates/burn-tensor/src/tests/ops/round.rs | 4 +- crates/burn-tensor/src/tests/ops/sin.rs | 2 +- crates/burn-tensor/src/tests/ops/slice.rs | 52 +-- .../burn-tensor/src/tests/ops/sort_argsort.rs | 67 +++- crates/burn-tensor/src/tests/ops/sqrt.rs | 4 +- crates/burn-tensor/src/tests/ops/squeeze.rs | 40 +-- crates/burn-tensor/src/tests/ops/stack.rs | 10 +- crates/burn-tensor/src/tests/ops/sub.rs | 20 +- crates/burn-tensor/src/tests/ops/tanh.rs | 2 +- crates/burn-tensor/src/tests/ops/transpose.rs | 8 +- crates/burn-tensor/src/tests/ops/tri_mask.rs | 16 +- .../src/tests/quantization/calibration.rs | 3 +- .../src/tests/quantization/ops/mask.rs | 4 +- .../src/tests/quantization/ops/powf.rs | 4 +- 
.../src/tests/quantization/ops/powf_scalar.rs | 4 +- .../src/tests/quantization/ops/quantize.rs | 8 +- .../src/tests/quantization/ops/reshape.rs | 2 +- .../src/tests/quantization/ops/slice.rs | 10 +- .../src/tests/quantization/scheme.rs | 8 +- crates/burn-tensor/src/tests/stats/cov.rs | 12 +- crates/burn-tensor/src/tests/stats/display.rs | 4 +- crates/burn-tensor/src/tests/stats/eye.rs | 4 +- crates/burn-wgpu/Cargo.toml | 2 + crates/burn-wgpu/src/lib.rs | 10 +- 151 files changed, 2367 insertions(+), 1557 deletions(-) delete mode 100644 crates/burn-jit/src/fusion/kernel.rs diff --git a/Cargo.lock b/Cargo.lock index 8fa3a46ca..a9cb04834 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -313,7 +313,7 @@ dependencies = [ "clap 4.5.20", "colored", "cubecl", - "derive-new", + "derive-new 0.7.0", "dirs 5.0.1", "github-device-flow", "half", @@ -321,7 +321,7 @@ dependencies = [ "os_info", "percent-encoding", "rand", - "reqwest 0.12.8", + "reqwest 0.12.9", "rstest", "serde", "serde_json", @@ -503,7 +503,7 @@ dependencies = [ "burn-common", "burn-tensor", "burn-tensor-testgen", - "derive-new", + "derive-new 0.7.0", "log", "spin", ] @@ -516,7 +516,7 @@ dependencies = [ "burn-tch", "burn-tensor", "candle-core", - "derive-new", + "derive-new 0.7.0", "half", ] @@ -529,7 +529,7 @@ dependencies = [ "getrandom", "indicatif", "rayon", - "reqwest 0.12.8", + "reqwest 0.12.9", "tokio", "web-time", ] @@ -552,7 +552,7 @@ dependencies = [ "burn-tensor", "burn-wgpu", "data-encoding", - "derive-new", + "derive-new 0.7.0", "flate2", "half", "hashbrown 0.15.0", @@ -579,9 +579,10 @@ dependencies = [ "burn-tensor", "bytemuck", "cubecl", - "derive-new", + "derive-new 0.7.0", "half", "log", + "paste", ] [[package]] @@ -590,7 +591,7 @@ version = "0.16.0" dependencies = [ "burn-common", "csv", - "derive-new", + "derive-new 0.7.0", "dirs 5.0.1", "fake", "flate2", @@ -620,7 +621,7 @@ dependencies = [ name = "burn-derive" version = "0.16.0" dependencies = [ - "derive-new", + "derive-new 0.7.0", "proc-macro2", "quote", "syn 2.0.87", @@ -632,7 +633,7 @@ version = "0.16.0" dependencies = [ "burn-common", "burn-tensor", - "derive-new", + "derive-new 0.7.0", "half", "hashbrown 0.15.0", "log", @@ -649,7 +650,7 @@ dependencies = [ "burn-tensor", "bytemuck", "cubecl", - "derive-new", + "derive-new 0.7.0", "half", "log", ] @@ -660,7 +661,7 @@ version = "0.16.0" dependencies = [ "burn", "candle-core", - "derive-new", + "derive-new 0.7.0", "half", "log", "onnx-ir", @@ -691,12 +692,13 @@ dependencies = [ "burn-tensor-testgen", "bytemuck", "cubecl", - "derive-new", + "derive-new 0.7.0", "futures-lite", "half", "hashbrown 0.15.0", "log", "num-traits", + "paste", "rand", "serde", "serial_test", @@ -713,7 +715,7 @@ dependencies = [ "burn-autodiff", "burn-common", "burn-tensor", - "derive-new", + "derive-new 0.7.0", "libm", "matrixmultiply", "ndarray 0.16.1", @@ -767,7 +769,7 @@ dependencies = [ "bytemuck", "colored", "cubecl", - "derive-new", + "derive-new 0.7.0", "half", "hashbrown 0.15.0", "num-traits", @@ -792,7 +794,7 @@ version = "0.16.0" dependencies = [ "burn-core", "burn-ndarray", - "derive-new", + "derive-new 0.7.0", "log", "nvml-wrapper", "ratatui", @@ -812,6 +814,8 @@ dependencies = [ "burn-jit", "burn-tensor", "cubecl", + "half", + "paste", ] [[package]] @@ -881,15 +885,15 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.6.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7" +checksum = 
"7e1a39b963e261c58017edf2007e5b63425ad21538aaaf51fe23d1da41703701" dependencies = [ "accelerate-src", "byteorder", "candle-kernels", "candle-metal-kernels", - "cudarc 0.11.9", + "cudarc", "gemm", "half", "libc", @@ -908,18 +912,18 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.6.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488" +checksum = "539cbfbf2d1d68a6ed97115e579c77c98f8ed0cfe7edbc6d7d30d2ac0c9e3d50" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.6.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f889aacd02fd999620a0435133d7cf3b58c81ef9dd5e47c38939b7a72345ea86" +checksum = "166a92826d615d98b205e346e52128fa0439f2ab3302587403fdc558b4219e19" dependencies = [ "metal 0.27.0", "once_cell", @@ -1460,7 +1464,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1476,7 +1480,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4" dependencies = [ - "derive-new", + "derive-new 0.6.0", "embassy-futures", "futures-lite", "getrandom", @@ -1491,9 +1495,9 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ - "derive-new", + "derive-new 0.6.0", "embassy-futures", "futures-lite", "getrandom", @@ -1508,13 +1512,14 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "bytemuck", "cubecl-common 0.4.0", "cubecl-macros", "cubecl-runtime 0.4.0", - "derive-new", + "derive-new 0.6.0", + "derive_more 1.0.0", "half", "log", "num-traits", @@ -1525,13 +1530,13 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "bytemuck", "cubecl-common 0.4.0", "cubecl-core", "cubecl-runtime 0.4.0", - "derive-new", + "derive-new 0.6.0", "half", "log", ] @@ -1539,15 +1544,15 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ 
"bytemuck", "cubecl-common 0.4.0", "cubecl-core", "cubecl-cpp", "cubecl-runtime 0.4.0", - "cudarc 0.12.1", - "derive-new", + "cudarc", + "derive-new 0.6.0", "half", "log", ] @@ -1555,7 +1560,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1563,7 +1568,7 @@ dependencies = [ "cubecl-cpp", "cubecl-hip-sys", "cubecl-runtime 0.4.0", - "derive-new", + "derive-new 0.6.0", "half", "log", ] @@ -1580,7 +1585,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "bytemuck", "cubecl-core", @@ -1591,11 +1596,11 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "cubecl-common 0.4.0", "darling", - "derive-new", + "derive-new 0.6.0", "ident_case", "prettyplease", "proc-macro2", @@ -1606,7 +1611,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1628,7 +1633,7 @@ dependencies = [ "async-lock", "cfg_aliases 0.2.1", "cubecl-common 0.3.0", - "derive-new", + "derive-new 0.6.0", "dirs 5.0.1", "hashbrown 0.14.5", "log", @@ -1643,13 +1648,13 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "async-channel", "async-lock", "cfg_aliases 0.2.1", "cubecl-common 0.4.0", - "derive-new", + "derive-new 0.6.0", "dirs 5.0.1", "hashbrown 0.14.5", "log", @@ -1664,12 +1669,13 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", "cubecl-opt", "cubecl-runtime 0.4.0", + "half", "hashbrown 0.14.5", "rspirv", ] @@ -1677,7 +1683,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" dependencies = [ "ash", "async-channel", @@ -1688,29 +1694,20 @@ dependencies = [ "cubecl-core", "cubecl-runtime 0.4.0", "cubecl-spirv", - "derive-new", + "derive-new 0.6.0", "hashbrown 0.14.5", "log", "web-time", "wgpu", ] -[[package]] -name = "cudarc" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bd4d1eee570c3b2ac64ed114125517dd1e541d88dd28fc259f1de4dba8d60" -dependencies = [ - "half", - "libloading", -] - [[package]] name = "cudarc" version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" dependencies = [ + "half", "libloading", ] @@ -1720,7 +1717,7 @@ version = "0.16.0" dependencies = [ "burn", "csv", - "reqwest 0.12.8", + "reqwest 0.12.9", "serde", ] @@ -1732,7 +1729,7 @@ dependencies = [ "burn-jit", "bytemuck", "cubecl", - "derive-new", + "derive-new 0.7.0", "log", "serde", ] @@ -1752,7 +1749,7 @@ version = "0.16.0" dependencies = [ "burn", "bytemuck", - "derive-new", + "derive-new 0.7.0", "guide", "log", "serde", @@ -1764,7 +1761,7 @@ version = "0.16.0" dependencies = [ "burn", "bytemuck", - "derive-new", + "derive-new 0.7.0", "guide", "log", "serde", @@ -1777,7 +1774,7 @@ dependencies = [ "burn", "bytemuck", "cubecl", - "derive-new", + "derive-new 0.7.0", "log", "serde", ] @@ -1874,6 +1871,17 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "derive_arbitrary" version = "1.3.2" @@ -1927,6 +1935,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "unicode-xid", +] + [[package]] name = "deunicode" version = "1.6.0" @@ -2224,9 +2253,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" +checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb" dependencies = [ "simd-adler32", ] @@ -3191,9 +3220,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", @@ -5669,9 +5698,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", @@ -5843,9 +5872,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" dependencies = [ "bitflags 2.6.0", "errno", @@ -5856,9 +5885,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.15" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ "log", "once_cell", @@ -5970,9 +5999,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.2" +version = "2.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2c1f7fc6deb21665a9060dfc7d271be784669295a31babdcd4dd2c79ae8cbfb" +checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8" dependencies = [ "sdd", ] @@ -6649,7 +6678,7 @@ name = "text-classification" version = "0.16.0" dependencies = [ "burn", - "derive-new", + "derive-new 0.7.0", "serde", "tokenizers", ] @@ -6659,7 +6688,7 @@ name = "text-generation" version = "0.16.0" dependencies = [ "burn", - "derive-new", + "derive-new 0.7.0", "log", "serde", "tokenizers", @@ -6946,7 +6975,7 @@ checksum = "4126466aafe1c518cb5c23979c286903cb1d1ff1bc3b76891254a243a0ed1e15" dependencies = [ "anyhow", "clap 4.5.20", - "derive_more", + "derive_more 0.99.18", "env_logger", "log", "rand", diff --git a/Cargo.toml b/Cargo.toml index e0461f8b8..81ada8b70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ version = "0.16.0" [workspace.dependencies] atomic_float = "1" bytemuck = "1.19.0" -candle-core = { version = "0.6.0" } +candle-core = { version = "0.7" } clap = { version = "4.5.20", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" @@ -53,6 +53,7 @@ js-sys = "0.3.69" libm = "0.2.9" log = { default-features = false, version = "0.4.22" } md5 = "0.7.0" +paste = "1" percent-encoding = "2.3.1" polars = { version = "0.41.3", features = ["lazy"] } pretty_assertions = "1.4.1" @@ -117,7 +118,7 @@ bincode = { version = "2.0.0-rc.3", features = [ # # The following packages disable the "std" feature for no_std compatibility # -derive-new = { version = "0.6.0", default-features = false } +derive-new = { version = "0.7.0", default-features = false } blas-src = { version = "0.10.0", default-features = false } half = { version = "2.4.1", features = [ @@ -152,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.2", features = ["alloc"] } ### For the main burn branch. 
### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-autodiff/src/runtime/mspc.rs b/crates/burn-autodiff/src/runtime/mspc.rs index 46d54ddbe..c128f34b8 100644 --- a/crates/burn-autodiff/src/runtime/mspc.rs +++ b/crates/burn-autodiff/src/runtime/mspc.rs @@ -73,9 +73,9 @@ impl AutodiffClient for ChannelClient { .unwrap() } - fn backward(&self, root: AutodiffTensor) -> Gradients { + fn backward(&self, root: AutodiffTensor) -> Gradients { let node_id = root.node.id; - let grads = Gradients::new::(root.node, root.primitive); + let grads = Gradients::new::(root.node, root.primitive); let (callback, receiver) = std::sync::mpsc::channel(); self.sender diff --git a/crates/burn-autodiff/src/tests/abs.rs b/crates/burn-autodiff/src/tests/abs.rs index be754be5b..41cb43d9f 100644 --- a/crates/burn-autodiff/src/tests/abs.rs +++ b/crates/burn-autodiff/src/tests/abs.rs @@ -20,10 +20,10 @@ mod tests { let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]); - grad_1.to_data().assert_approx_eq(&expected, 3); + grad_1.to_data().assert_approx_eq(&expected, 5); let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]); - grad_2.to_data().assert_approx_eq(&expected, 3); + grad_2.to_data().assert_approx_eq(&expected, 5); } #[test] @@ -42,10 +42,10 @@ mod tests { let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]); - grad_1.to_data().assert_approx_eq(&expected, 3); + grad_1.to_data().assert_approx_eq(&expected, 5); let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]); - grad_2.to_data().assert_approx_eq(&expected, 3); + grad_2.to_data().assert_approx_eq(&expected, 5); let contains_nan = grad_2.contains_nan(); assert_eq!(contains_nan.into_scalar(), false); diff --git a/crates/burn-autodiff/src/tests/adaptive_avgpool1d.rs b/crates/burn-autodiff/src/tests/adaptive_avgpool1d.rs index 48994ebcb..9e4fb7c01 100644 --- a/crates/burn-autodiff/src/tests/adaptive_avgpool1d.rs +++ b/crates/burn-autodiff/src/tests/adaptive_avgpool1d.rs @@ -46,7 +46,7 @@ mod tests { x_grad .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + .assert_approx_eq(&x_grad_actual.into_data(), 4); } } } diff --git a/crates/burn-autodiff/src/tests/adaptive_avgpool2d.rs b/crates/burn-autodiff/src/tests/adaptive_avgpool2d.rs index d87e8c00c..b984140a2 100644 --- a/crates/burn-autodiff/src/tests/adaptive_avgpool2d.rs +++ b/crates/burn-autodiff/src/tests/adaptive_avgpool2d.rs @@ -62,7 +62,7 @@ mod tests { x_grad .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + .assert_approx_eq(&x_grad_actual.into_data(), 4); } } } diff --git a/crates/burn-autodiff/src/tests/avgpool1d.rs b/crates/burn-autodiff/src/tests/avgpool1d.rs index d4ac5d269..3f5d81d03 100644 --- a/crates/burn-autodiff/src/tests/avgpool1d.rs +++ 
b/crates/burn-autodiff/src/tests/avgpool1d.rs @@ -97,7 +97,7 @@ mod tests { x_grad .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + .assert_approx_eq(&x_grad_actual.into_data(), 4); } } } diff --git a/crates/burn-autodiff/src/tests/avgpool2d.rs b/crates/burn-autodiff/src/tests/avgpool2d.rs index 902b3af26..310098899 100644 --- a/crates/burn-autodiff/src/tests/avgpool2d.rs +++ b/crates/burn-autodiff/src/tests/avgpool2d.rs @@ -124,7 +124,7 @@ mod tests { x_grad .to_data() - .assert_approx_eq(&x_grad_actual.into_data(), 3); + .assert_approx_eq(&x_grad_actual.into_data(), 4); } } } diff --git a/crates/burn-autodiff/src/tests/cat.rs b/crates/burn-autodiff/src/tests/cat.rs index 86c72e924..18777c079 100644 --- a/crates/burn-autodiff/src/tests/cat.rs +++ b/crates/burn-autodiff/src/tests/cat.rs @@ -40,21 +40,21 @@ mod tests { .clone() .slice([0..1]) .to_data() - .assert_approx_eq(&grad_1_slice_1.to_data(), 3); + .assert_approx_eq(&grad_1_slice_1.to_data(), 5); grad_1 .slice([1..2]) .to_data() - .assert_approx_eq(&grad_1_slice_2.to_data(), 3); + .assert_approx_eq(&grad_1_slice_2.to_data(), 5); grad_2 .clone() .slice([0..1]) .to_data() - .assert_approx_eq(&grad_2_slice_1.to_data(), 3); + .assert_approx_eq(&grad_2_slice_1.to_data(), 5); grad_2 .slice([1..2]) .to_data() - .assert_approx_eq(&grad_2_slice_2.to_data(), 3); + .assert_approx_eq(&grad_2_slice_2.to_data(), 5); } #[test] diff --git a/crates/burn-autodiff/src/tests/conv1d.rs b/crates/burn-autodiff/src/tests/conv1d.rs index 570b3e500..adba2885a 100644 --- a/crates/burn-autodiff/src/tests/conv1d.rs +++ b/crates/burn-autodiff/src/tests/conv1d.rs @@ -265,15 +265,15 @@ mod tests { expected_grads .bias .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); + .assert_approx_eq(&bias_grad_actual.to_data(), 5); expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq(&weight_grad_actual.to_data(), 5); expected_grads .x .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + .assert_approx_eq(&x_grad_actual.to_data(), 5); } } } diff --git a/crates/burn-autodiff/src/tests/conv2d.rs b/crates/burn-autodiff/src/tests/conv2d.rs index 6351936e7..3a095de7b 100644 --- a/crates/burn-autodiff/src/tests/conv2d.rs +++ b/crates/burn-autodiff/src/tests/conv2d.rs @@ -883,15 +883,15 @@ mod tests { expected_grads .bias .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); + .assert_approx_eq(&bias_grad_actual.to_data(), 5); expected_grads .x .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + .assert_approx_eq(&x_grad_actual.to_data(), 5); expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq(&weight_grad_actual.to_data(), 5); } } } diff --git a/crates/burn-autodiff/src/tests/conv3d.rs b/crates/burn-autodiff/src/tests/conv3d.rs index 5d0b29088..7dfb3f686 100644 --- a/crates/burn-autodiff/src/tests/conv3d.rs +++ b/crates/burn-autodiff/src/tests/conv3d.rs @@ -512,15 +512,15 @@ mod tests { expected_grads .bias .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); + .assert_approx_eq(&bias_grad_actual.to_data(), 5); expected_grads .x .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + .assert_approx_eq(&x_grad_actual.to_data(), 5); expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq(&weight_grad_actual.to_data(), 5); } } } diff --git a/crates/burn-autodiff/src/tests/conv_transpose1d.rs b/crates/burn-autodiff/src/tests/conv_transpose1d.rs 
index ae8e2840e..db1a17fcb 100644 --- a/crates/burn-autodiff/src/tests/conv_transpose1d.rs +++ b/crates/burn-autodiff/src/tests/conv_transpose1d.rs @@ -281,15 +281,15 @@ mod tests { expected_grads .bias .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); + .assert_approx_eq(&bias_grad_actual.to_data(), 5); expected_grads .x .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + .assert_approx_eq(&x_grad_actual.to_data(), 5); expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq(&weight_grad_actual.to_data(), 5); } } } diff --git a/crates/burn-autodiff/src/tests/conv_transpose2d.rs b/crates/burn-autodiff/src/tests/conv_transpose2d.rs index fedf1b472..94201074d 100644 --- a/crates/burn-autodiff/src/tests/conv_transpose2d.rs +++ b/crates/burn-autodiff/src/tests/conv_transpose2d.rs @@ -694,15 +694,15 @@ mod tests { expected_grads .bias .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); + .assert_approx_eq(&bias_grad_actual.to_data(), 5); expected_grads .x .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + .assert_approx_eq(&x_grad_actual.to_data(), 5); expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq(&weight_grad_actual.to_data(), 5); } } } diff --git a/crates/burn-autodiff/src/tests/conv_transpose3d.rs b/crates/burn-autodiff/src/tests/conv_transpose3d.rs index 49575e2f3..553abc891 100644 --- a/crates/burn-autodiff/src/tests/conv_transpose3d.rs +++ b/crates/burn-autodiff/src/tests/conv_transpose3d.rs @@ -567,15 +567,15 @@ mod tests { expected_grads .bias .to_data() - .assert_approx_eq(&bias_grad_actual.to_data(), 3); + .assert_approx_eq(&bias_grad_actual.to_data(), 5); expected_grads .x .to_data() - .assert_approx_eq(&x_grad_actual.to_data(), 3); + .assert_approx_eq(&x_grad_actual.to_data(), 5); expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq(&weight_grad_actual.to_data(), 5); } } } diff --git a/crates/burn-autodiff/src/tests/deform_conv2d.rs b/crates/burn-autodiff/src/tests/deform_conv2d.rs index df41cdb81..ed576a10d 100644 --- a/crates/burn-autodiff/src/tests/deform_conv2d.rs +++ b/crates/burn-autodiff/src/tests/deform_conv2d.rs @@ -1408,7 +1408,7 @@ mod tests { expected_grads .weight .to_data() - .assert_approx_eq(&weight_grad_actual.to_data(), 3); + .assert_approx_eq_diff(&weight_grad_actual.to_data(), 0.04); } } } diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 6fdafc1be..70728394c 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -67,11 +67,54 @@ mod transpose; #[macro_export] macro_rules! testgen_all { + // Avoid using paste dependency with no parameters () => { - type TestAutodiffBackend = burn_autodiff::Autodiff; - type TestAutodiffTensor = burn_tensor::Tensor; + mod autodiff { + pub use super::*; + type TestAutodiffBackend = burn_autodiff::Autodiff; + type TestAutodiffTensor = burn_tensor::Tensor; - // Behavior + pub type FloatType = ::FloatElem; + pub type IntType = ::IntElem; + pub type BoolType = ::BoolTensorPrimitive; + + $crate::testgen_with_float_param!(); + } + }; + ([$($float:ident),*]) => { + mod autodiff { + pub use super::*; + type TestAutodiffBackend = burn_autodiff::Autodiff; + type TestAutodiffTensor = burn_tensor::Tensor; + + pub type FloatType = ::FloatElem; + pub type IntType = ::IntElem; + pub type BoolType = ::BoolTensorPrimitive; + + ::paste::paste! 
{ + $(mod [<$float _ty>] { + pub use super::*; + + pub type TestBackend = TestBackend2<$float, IntType>; + pub type TestAutodiffBackend = burn_autodiff::Autodiff; + pub type TestAutodiffTensor = burn_tensor::Tensor; + pub type TestTensor = TestTensor2<$float, IntType, D>; + pub type TestTensorInt = TestTensorInt2<$float, IntType, D>; + pub type TestTensorBool = TestTensorBool2<$float, IntType, D>; + + type FloatType = $float; + + $crate::testgen_with_float_param!(); + })* + } + } + }; +} + +#[macro_export] +macro_rules! testgen_with_float_param { + () => { + // Behaviour burn_autodiff::testgen_ad_broadcast!(); burn_autodiff::testgen_gradients!(); burn_autodiff::testgen_bridge!(); diff --git a/crates/burn-autodiff/src/tests/sqrt.rs b/crates/burn-autodiff/src/tests/sqrt.rs index 1ffd94b78..daf7a8a13 100644 --- a/crates/burn-autodiff/src/tests/sqrt.rs +++ b/crates/burn-autodiff/src/tests/sqrt.rs @@ -20,9 +20,9 @@ mod tests { let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[82.1126, 99.0832], [82.1126, 99.0832]]); - grad_1.to_data().assert_approx_eq(&expected, 3); + grad_1.to_data().assert_approx_eq_diff(&expected, 0.02); let expected = TensorData::from([[30.3093, 33.1204], [34.5819, 38.7694]]); - grad_2.to_data().assert_approx_eq(&expected, 3); + grad_2.to_data().assert_approx_eq(&expected, 2); } } diff --git a/crates/burn-autodiff/src/tests/transpose.rs b/crates/burn-autodiff/src/tests/transpose.rs index b87962151..5219e8b9f 100644 --- a/crates/burn-autodiff/src/tests/transpose.rs +++ b/crates/burn-autodiff/src/tests/transpose.rs @@ -21,10 +21,10 @@ mod tests { grad_1 .to_data() - .assert_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), false); + .assert_approx_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), 3); grad_2 .to_data() - .assert_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), false); + .assert_approx_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), 3); } #[test] @@ -48,13 +48,13 @@ mod tests { let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); - grad_1.to_data().assert_eq( + grad_1.to_data().assert_approx_eq( &TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]), - false, + 3, ); - grad_2.to_data().assert_eq( + grad_2.to_data().assert_approx_eq( &TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]), - false, + 3, ); } } diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index 8f0d496de..67586b071 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -34,6 +34,8 @@ mod tests { type TestAutodiffBackend = burn_autodiff::Autodiff; type TestAutodiffTensor = burn_tensor::Tensor; + pub type FloatType = f32; + // test activation burn_tensor::testgen_gelu!(); burn_tensor::testgen_prelu!(); diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index 7784972c7..c79c21a78 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -523,7 +523,7 @@ mod tests { // Should produce the same tokens. 
output_1 .into_data() - .assert_approx_eq(&output_2.into_data(), 3); + .assert_approx_eq(&output_2.into_data(), 2); } #[test] diff --git a/crates/burn-core/src/nn/transformer/encoder.rs b/crates/burn-core/src/nn/transformer/encoder.rs index 6aea721d3..03521bc2b 100644 --- a/crates/burn-core/src/nn/transformer/encoder.rs +++ b/crates/burn-core/src/nn/transformer/encoder.rs @@ -441,7 +441,7 @@ mod tests { output_1 .into_data() - .assert_approx_eq(&output_2.into_data(), 3); + .assert_approx_eq(&output_2.into_data(), 2); } #[test] diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 4a5456e1a..708f16227 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -2,38 +2,43 @@ authors = ["nathanielsimard "] categories = ["science"] description = "CUDA backend for the Burn framework" +documentation = "https://docs.rs/burn-cuda" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu", "cuda"] license.workspace = true name = "burn-cuda" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda" -documentation = "https://docs.rs/burn-cuda" version.workspace = true [features] -default = ["fusion", "burn-jit/default", "cubecl/default"] -fusion = ["burn-fusion", "burn-jit/fusion"] autotune = ["burn-jit/autotune"] +default = ["fusion", "burn-jit/default", "cubecl/default"] doc = ["burn-jit/doc"] +fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] [dependencies] -cubecl = { workspace = true, features = ["cuda"] } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = ["cubecl-cuda"] } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ + "cubecl-cuda", +] } +cubecl = { workspace = true, features = ["cuda"] } -half = { workspace = true } bytemuck = { workspace = true } +half = { workspace = true } -log = { workspace = true } derive-new = { workspace = true } +log = { workspace = true } + [dev-dependencies] burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ "export_tests", ] } +paste = { workspace = true } + [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index a3f9a0218..086d00bab 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -17,6 +17,7 @@ mod tests { use burn_jit::JitBackend; pub type TestRuntime = cubecl::cuda::CudaRuntime; + pub use half::{bf16, f16}; - burn_jit::testgen_all!(); + burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]); } diff --git a/crates/burn-dataset/examples/speech_commands.rs b/crates/burn-dataset/examples/speech_commands.rs index cce7f131e..e54e9dff5 100644 --- a/crates/burn-dataset/examples/speech_commands.rs +++ b/crates/burn-dataset/examples/speech_commands.rs @@ -9,7 +9,7 @@ fn speech_command() { println!("Item: {:?}", item); println!("Item Length: {:?}", item.audio_samples.len()); - println!("Label: {}", item.label.to_string()); + println!("Label: {}", item.label); assert_eq!(test.len(), 4890); assert_eq!(item.label.to_string(), "Yes"); diff --git a/crates/burn-dataset/src/audio/speech_commands.rs b/crates/burn-dataset/src/audio/speech_commands.rs index 28c2d34f2..77a4a4a59 100644 --- 
a/crates/burn-dataset/src/audio/speech_commands.rs +++ b/crates/burn-dataset/src/audio/speech_commands.rs @@ -5,7 +5,7 @@ use crate::{ use hound::WavReader; use serde::{Deserialize, Serialize}; -use strum::EnumCount; +use strum::EnumCount as _; use strum_macros::{Display, EnumCount, FromRepr}; type MappedDataset = MapperDataset, ConvertSamples, SpeechItemRaw>; diff --git a/crates/burn-dataset/src/dataset/dataframe.rs b/crates/burn-dataset/src/dataset/dataframe.rs index 283726648..4cfce856b 100644 --- a/crates/burn-dataset/src/dataset/dataframe.rs +++ b/crates/burn-dataset/src/dataset/dataframe.rs @@ -367,7 +367,7 @@ mod tests { let dataset: DataframeDataset = DataframeDataset::new(df).unwrap(); assert_eq!(dataset.len(), 3); - assert_eq!(dataset.is_empty(), false); + assert!(!dataset.is_empty()); let item = dataset.get(1).unwrap(); assert_eq!( @@ -439,7 +439,7 @@ mod tests { let dataset = DataframeDataset::::new(df).unwrap(); assert_eq!(dataset.len(), 3); - assert_eq!(dataset.is_empty(), false); + assert!(!dataset.is_empty()); let item = dataset.get(1).unwrap(); assert_eq!( diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index b6b8886db..277b07d01 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -8,6 +8,7 @@ use hashbrown::HashMap; /// /// It also contains all scalar values, which can change even for the same graph. They are sorted /// in the order in which they appear in the graph. +#[allow(clippy::too_many_arguments)] #[derive(new)] pub struct Context<'a, H> { /// The tensor mapping where local tensor id points to the updated tensor description. @@ -20,8 +21,22 @@ pub struct Context<'a, H> { pub scalar_f16: &'a Vec, /// BF16 scalars found in the graph in the order they appeared. pub scalar_bf16: &'a Vec, - /// Int scalars found in the graph in the order they appeared. - pub scalar_ints: &'a Vec, + /// i64 scalars found in the graph in the order they appeared. + pub scalar_i64: &'a Vec, + /// i32 scalars found in the graph in the order they appeared. + pub scalar_i32: &'a Vec, + /// i16 scalars found in the graph in the order they appeared. + pub scalar_i16: &'a Vec, + /// i8 scalars found in the graph in the order they appeared. + pub scalar_i8: &'a Vec, + /// u64 scalars found in the graph in the order they appeared. + pub scalar_u64: &'a Vec, + /// u32 scalars found in the graph in the order they appeared. + pub scalar_u32: &'a Vec, + /// u16 scalars found in the graph in the order they appeared. + pub scalar_u16: &'a Vec, + /// u8 scalars found in the graph in the order they appeared. 
+ pub scalar_u8: &'a Vec, } #[derive(Default)] @@ -34,7 +49,14 @@ pub(crate) struct OperationConverter { scalar_f32: Vec, scalar_f16: Vec, scalar_bf16: Vec, - scalar_ints: Vec, + scalar_i64: Vec, + scalar_i32: Vec, + scalar_i16: Vec, + scalar_i8: Vec, + scalar_u64: Vec, + scalar_u32: Vec, + scalar_u16: Vec, + scalar_u8: Vec, } pub(crate) trait RelativeOps { @@ -63,7 +85,14 @@ impl OperationConverter { scalar_f32: &self.scalar_f32, scalar_f16: &self.scalar_f16, scalar_bf16: &self.scalar_bf16, - scalar_ints: &self.scalar_ints, + scalar_i64: &self.scalar_i64, + scalar_i32: &self.scalar_i32, + scalar_i16: &self.scalar_i16, + scalar_i8: &self.scalar_i8, + scalar_u64: &self.scalar_u64, + scalar_u32: &self.scalar_u32, + scalar_u16: &self.scalar_u16, + scalar_u8: &self.scalar_u8, } } @@ -74,7 +103,14 @@ impl OperationConverter { self.scalar_f32.clear(); self.scalar_f16.clear(); self.scalar_bf16.clear(); - self.scalar_ints.clear(); + self.scalar_i64.clear(); + self.scalar_i32.clear(); + self.scalar_i16.clear(); + self.scalar_i8.clear(); + self.scalar_u64.clear(); + self.scalar_u32.clear(); + self.scalar_u16.clear(); + self.scalar_u8.clear(); } pub(crate) fn relative_float(&mut self, elem: &E, dtype: &DType) -> E { @@ -90,8 +126,18 @@ impl OperationConverter { 0.elem() } - pub(crate) fn relative_int(&mut self, elem: &E) -> E { - self.scalar_ints.push(elem.elem()); + pub(crate) fn relative_int(&mut self, elem: &E, dtype: &DType) -> E { + match dtype { + DType::I64 => self.scalar_i64.push(elem.elem()), + DType::I32 => self.scalar_i32.push(elem.elem()), + DType::I16 => self.scalar_i16.push(elem.elem()), + DType::I8 => self.scalar_i8.push(elem.elem()), + DType::U64 => self.scalar_u64.push(elem.elem()), + DType::U32 => self.scalar_u32.push(elem.elem()), + DType::U16 => self.scalar_u16.push(elem.elem()), + DType::U8 => self.scalar_u8.push(elem.elem()), + _ => todo!("Unsupported"), + } // We return 0 so that the id from a scalar operation is the same no matter its scalar // value. 
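// Illustrative sketch only (not part of the patch): how a scalar is routed through the
// new per-dtype dispatch above, assuming `converter` is an `OperationConverter` and
// `DType` is `burn_tensor::DType`. Each call pushes the value into the matching
// `scalar_*` buffer and returns zero, so the relative graph key stays independent of
// the concrete scalar value:
//
//     let a: i16 = converter.relative_int(&5i16, &DType::I16); // pushed to `scalar_i16`, returns 0
//     let b: u8 = converter.relative_int(&7u8, &DType::U8);    // pushed to `scalar_u8`, returns 0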
0.elem() @@ -116,7 +162,7 @@ impl RelativeOps for OperationDescription { ), OperationDescription::NumericInt(dtype, ops) => OperationDescription::NumericInt( *dtype, - ops.to_relative(converter, |converter, e| converter.relative_int(e)), + ops.to_relative(converter, |converter, e| converter.relative_int(e, dtype)), ), OperationDescription::Bool(ops) => { OperationDescription::Bool(ops.to_relative(converter)) diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 716ba6a82..ce5b22b8a 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -2,13 +2,13 @@ authors = ["nathanielsimard "] categories = ["science"] description = "Generic backend that can be compiled just-in-time to any shader language target" +documentation = "https://docs.rs/burn-jit" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu"] license.workspace = true name = "burn-jit" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit" -documentation = "https://docs.rs/burn-jit" version.workspace = true [features] @@ -22,6 +22,7 @@ export_tests = [ "burn-tensor/export_tests", "burn-ndarray", "fusion", + "paste", ] fusion = ["burn-fusion"] std = ["cubecl/std"] @@ -31,7 +32,8 @@ template = [] burn-common = { path = "../burn-common", version = "0.16.0" } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ - "cubecl", "repr", + "cubecl", + "repr", ] } cubecl = { workspace = true, features = ["linalg"] } @@ -56,6 +58,7 @@ hashbrown = { workspace = true } # When exporting tests burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, optional = true } burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true } +paste = { workspace = true, optional = true } serial_test = { workspace = true, optional = true } [package.metadata.docs.rs] diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index 64e9ce029..939b2fb24 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -1,4 +1,5 @@ use cubecl::{ + flex32, prelude::{Float, Int, Numeric}, CubeElement, }; @@ -12,17 +13,26 @@ pub trait FloatElement: JitElement + Float {} /// The int element type for the jit backend. 
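// Illustrative sketch only, under the assumption that the backend aliases expose the
// float and int element parameters (e.g. something like `Wgpu<F, I>` or
// `JitBackend<R, F, I>`); the alias names here are hypothetical and not taken from this
// patch. With the additional `FloatElement`/`IntElement` impls below, a JIT backend can
// be instantiated with reduced-precision element types:
//
//     type HalfBackend = burn_wgpu::Wgpu<half::f16, i16>; // hypothetical alias usage
//     type ByteIntBackend = burn_wgpu::Wgpu<f32, i8>;     // f32 floats, i8 ints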
pub trait IntElement: JitElement + Int {} +impl JitElement for u64 {} impl JitElement for u32 {} - +impl JitElement for u16 {} +impl JitElement for u8 {} +impl JitElement for i64 {} impl JitElement for i32 {} - +impl JitElement for i16 {} +impl JitElement for i8 {} +impl JitElement for f64 {} impl JitElement for f32 {} - +impl JitElement for flex32 {} impl JitElement for half::f16 {} - impl JitElement for half::bf16 {} +impl FloatElement for f64 {} impl FloatElement for f32 {} +impl FloatElement for flex32 {} impl FloatElement for half::bf16 {} impl FloatElement for half::f16 {} +impl IntElement for i64 {} impl IntElement for i32 {} +impl IntElement for i16 {} +impl IntElement for i8 {} diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index 70a34e579..f5f300092 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -68,15 +68,29 @@ impl TraceRunner for ElemwiseOptimization { Arg::Input(index, precision, _) => match precision { ElemwisePrecision::F32 => inputs.t_f32.values.get(index as usize), ElemwisePrecision::F16 => inputs.t_f16.values.get(index as usize), + ElemwisePrecision::BF16 => inputs.t_bf16.values.get(index as usize), + ElemwisePrecision::U64 => inputs.t_u64.values.get(index as usize), ElemwisePrecision::U32 => inputs.t_u32.values.get(index as usize), + ElemwisePrecision::U16 => inputs.t_u16.values.get(index as usize), + ElemwisePrecision::U8 => inputs.t_u8.values.get(index as usize), + ElemwisePrecision::I64 => inputs.t_i64.values.get(index as usize), ElemwisePrecision::I32 => inputs.t_i32.values.get(index as usize), + ElemwisePrecision::I16 => inputs.t_i16.values.get(index as usize), + ElemwisePrecision::I8 => inputs.t_i8.values.get(index as usize), _ => panic!("Invalid value"), }, Arg::Output(index, precision, _) => match precision { ElemwisePrecision::F32 => outputs.t_f32.values.get(index as usize), ElemwisePrecision::F16 => outputs.t_f16.values.get(index as usize), + ElemwisePrecision::BF16 => outputs.t_bf16.values.get(index as usize), + ElemwisePrecision::U64 => outputs.t_u64.values.get(index as usize), ElemwisePrecision::U32 => outputs.t_u32.values.get(index as usize), + ElemwisePrecision::U16 => outputs.t_u16.values.get(index as usize), + ElemwisePrecision::U8 => outputs.t_u8.values.get(index as usize), + ElemwisePrecision::I64 => outputs.t_i64.values.get(index as usize), ElemwisePrecision::I32 => outputs.t_i32.values.get(index as usize), + ElemwisePrecision::I16 => outputs.t_i16.values.get(index as usize), + ElemwisePrecision::I8 => outputs.t_i8.values.get(index as usize), _ => panic!("Invalid value"), }, _ => panic!("Invalid value"), @@ -142,11 +156,11 @@ impl TraceRunner for ElemwiseOptimization { let mut output = u8::MAX; for (handle, tensor) in handles_inputs.zip(inputs) { - output = u8::min(vectorization_input(handle, tensor), output); + output = Ord::min(vectorization_input(handle, tensor), output); } for tensor in outputs { - output = u8::min(vectorization_output(tensor), output); + output = Ord::min(vectorization_output(tensor), output); } output @@ -168,15 +182,29 @@ fn elemwise_fuse( Arg::Input(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => inputs.t_f32.index(index).len(), ElemwisePrecision::F16 => inputs.t_f16.index(index).len(), + ElemwisePrecision::BF16 => inputs.t_bf16.index(index).len(), + ElemwisePrecision::U64 => inputs.t_u64.index(index).len(), ElemwisePrecision::U32 => 
inputs.t_u32.index(index).len(), + ElemwisePrecision::U16 => inputs.t_u16.index(index).len(), + ElemwisePrecision::U8 => inputs.t_u8.index(index).len(), + ElemwisePrecision::I64 => inputs.t_i64.index(index).len(), ElemwisePrecision::I32 => inputs.t_i32.index(index).len(), + ElemwisePrecision::I16 => inputs.t_i16.index(index).len(), + ElemwisePrecision::I8 => inputs.t_i8.index(index).len(), _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Output(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => outputs.t_f32.index(index).len(), ElemwisePrecision::F16 => outputs.t_f16.index(index).len(), + ElemwisePrecision::BF16 => outputs.t_bf16.index(index).len(), + ElemwisePrecision::U64 => outputs.t_u64.index(index).len(), ElemwisePrecision::U32 => outputs.t_u32.index(index).len(), + ElemwisePrecision::U16 => outputs.t_u16.index(index).len(), + ElemwisePrecision::U8 => outputs.t_u8.index(index).len(), + ElemwisePrecision::I64 => outputs.t_i64.index(index).len(), ElemwisePrecision::I32 => outputs.t_i32.index(index).len(), + ElemwisePrecision::I16 => outputs.t_i16.index(index).len(), + ElemwisePrecision::I8 => outputs.t_i8.index(index).len(), _ => comptime![panic!("Unsupported precision {precision:?}")], }, _ => comptime![panic!("Invalid ref layout.")], diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs deleted file mode 100644 index b51308d24..000000000 --- a/crates/burn-jit/src/fusion/kernel.rs +++ /dev/null @@ -1,337 +0,0 @@ -use cubecl::calculate_num_elems_dyn_rank; -use cubecl::prelude::*; - -use crate::fusion::strides_dyn_rank; -use crate::fusion::JitFusionHandle; -use crate::kernel::Kernel; -use crate::JitRuntime; -use burn_fusion::stream::Context; -use burn_tensor::repr::TensorDescription; -use burn_tensor::repr::TensorStatus; -use cubecl::client::ComputeClient; -use cubecl::server::Binding; -use cubecl::tune::AutotuneOperation; -use std::marker::PhantomData; -use std::sync::Arc; - -use super::tracing::ExecutionInfo; - -#[derive(new)] -pub struct FusionKernel { - id: u64, // Same ID for all different settings. - info: Arc, - settings: KernelSettings, - runtime_info: Vec, - cube_count: CubeCount, - _runtime: PhantomData, -} - -pub trait FusionKernelFactory { - /// Create a new kernel. - fn create( - &self, - handles_inputs: &[JitFusionHandle], - inputs: &[&TensorDescription], - outputs: &[&TensorDescription], - stateful: bool, // Should be set to false when running autotune. - ) -> FusionKernel; -} - -/// An instantiation of a [kernel](Kernel) that can be executed. -#[derive(new)] -pub struct ExecutableKernel { - kernel: Box>, - cube_count: CubeCount, - bindings: Vec>, - client: ComputeClient, -} - -/// An instantiation of a [kernel](Kernel) that can be autotuned. -/// -/// The main difference with an [executable kernel](ExecutableKernel) is that this kernel can be -/// cloned and executed multiple times to properly collect benchmarks. -/// -/// The clone function used is defined in the trait [AutotuneOperation] instead of [Clone]. -#[derive(new)] -pub struct AutotunableKernel { - kernel: Arc>, - count: CubeCount, - bindings: Vec>, - client: ComputeClient, -} - -// Information related to the output of this kernel. -#[derive(Debug)] -pub enum OutputRuntimeInfo { - Inplace { input_index: usize }, - Array { size: usize }, -} - -impl ExecutableKernel { - /// Execute the kernel. 
- pub fn execute(self) { - unsafe { - self.client - .execute_unchecked(self.kernel, self.cube_count, self.bindings) - } - } -} - -impl AutotuneOperation for AutotunableKernel { - fn execute(self: Box) { - self.client - .execute(Box::new(self.kernel), self.count, self.bindings) - } - - fn clone(&self) -> Box { - Box::new(Self { - kernel: self.kernel.clone(), - count: self.count.clone(), - bindings: self.bindings.clone(), - client: self.client.clone(), - }) - } -} - -impl From> for AutotunableKernel { - fn from(value: ExecutableKernel) -> Self { - Self { - kernel: Arc::new(value.kernel), - count: value.cube_count.clone(), - bindings: value.bindings, - client: value.client, - } - } -} - -impl FusionKernel { - pub fn create( - factory: &K, - running_info: &ExecutionInfo<'_>, - context: &mut Context<'_, JitFusionHandle>, - device: R::Device, - client: ComputeClient, - stateful: bool, - ) -> ExecutableKernel - where - K: FusionKernelFactory, - { - let (handles_input, inputs_description_updated, outputs_description_updated) = - process_inputs_outputs( - &running_info.inputs, - &running_info.outputs, - context, - stateful, - ); - - let fusion_kernel = factory.create( - &handles_input, - &inputs_description_updated, - &outputs_description_updated, - stateful, - ); - - let rank_input = running_info - .inputs - .first() - .map(|desc| desc.shape.len()) - .unwrap_or(1); - let rank_output = running_info - .outputs - .first() - .map(|desc| desc.shape.len()) - .unwrap_or(1); - let rank = usize::max(rank_input, rank_output); - - let num_tensors = running_info.inputs.len() + running_info.outputs.len(); - // The buffer starts with the rank, then each tensor shape and stride. - let info_size = (num_tensors * rank * 2) + 1; - - let mut num_handles = num_tensors + 1; - if running_info.scalars.num_f32 > 0 { - num_handles += 1; - } - if running_info.scalars.num_f16 > 0 { - num_handles += 1; - } - if running_info.scalars.num_bf16 > 0 { - num_handles += 1; - } - if running_info.scalars.num_int > 0 { - num_handles += 1; - } - - let mut info = Vec::with_capacity(info_size); - let mut bindings = Vec::with_capacity(num_handles); - let mut output_register = Vec::with_capacity(outputs_description_updated.len()); - - // We register the info and handles for the inputs. - for (handle, tensor) in handles_input.iter().zip(inputs_description_updated.iter()) { - register_info_tensor(&mut info, tensor, handle); - bindings.push(handle.handle.clone().binding()); - } - - // We register the info and handles for the outputs. - for (tensor, output_info) in outputs_description_updated - .iter() - .zip(fusion_kernel.runtime_info.iter()) - { - match output_info { - // Use the input inplace for this output. - OutputRuntimeInfo::Inplace { input_index } => { - let input = handles_input.get(*input_index).unwrap(); - - let handle_fusion = JitFusionHandle { - client: client.clone(), - device: device.clone(), - strides: strides_dyn_rank(&tensor.shape), - handle: input.handle.clone(), - }; - output_register.push((tensor.id, handle_fusion)); - } - // Create a new buffer for this output. - OutputRuntimeInfo::Array { size } => { - let handle_fusion = JitFusionHandle { - client: client.clone(), - device: device.clone(), - strides: strides_dyn_rank(&tensor.shape), - handle: client.empty(*size), - }; - - register_info_tensor(&mut info, tensor, &handle_fusion); - bindings.push(handle_fusion.handle.clone().binding()); - output_register.push((tensor.id, handle_fusion)); - } - }; - } - - // [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... 
O0..., I0len, I1len1, O0len] - if R::require_array_lengths() { - for input in inputs_description_updated.iter() { - let len = calculate_num_elems_dyn_rank(&input.shape); - info.push(len as u32); - } - - for output in outputs_description_updated.iter() { - let len = calculate_num_elems_dyn_rank(&output.shape); - info.push(len as u32); - } - } - - // Create the info buffer. - bindings.push(client.create(bytemuck::cast_slice(&info)).binding()); - - // Finally we finish with the named bindings. - if running_info.scalars.num_f32 > 0 { - let bytes = bytemuck::cast_slice(&context.scalar_f32[0..running_info.scalars.num_f32]); - bindings.push(client.create(bytes).binding()); - } - - if running_info.scalars.num_f16 > 0 { - let bytes = bytemuck::cast_slice(&context.scalar_f16[0..running_info.scalars.num_f16]); - bindings.push(client.create(bytes).binding()); - } - - if running_info.scalars.num_bf16 > 0 { - let bytes = - bytemuck::cast_slice(&context.scalar_bf16[0..running_info.scalars.num_bf16]); - bindings.push(client.create(bytes).binding()); - } - - if running_info.scalars.num_int > 0 { - bindings.push( - client - .create(bytemuck::cast_slice( - &context.scalar_ints[0..running_info.scalars.num_int], - )) - .binding(), - ); - } - - // We have to register the output handles to the context. - for (id, handle) in output_register { - context.handles.register_handle(id, handle); - } - - let cube_count = fusion_kernel.cube_count.clone(); - ExecutableKernel::new( - Box::new(KernelTask::::new(fusion_kernel)), - cube_count, - bindings, - client, - ) - } -} - -impl Kernel for FusionKernel { - fn define(&self) -> KernelDefinition { - log::info!("Compiling ... {:?}", self.id()); - KernelIntegrator::new(self.info.as_ref().clone()).integrate(self.settings.clone()) - } - - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::().info((self.settings.clone(), self.id)) - } -} - -fn register_info_tensor( - info: &mut Vec, - tensor: &TensorDescription, - handle: &JitFusionHandle, -) { - if info.is_empty() { - info.push(handle.strides.len() as u32); - } - - for s in handle.strides.iter() { - info.push(*s as u32); - } - for s in tensor.shape.iter() { - info.push(*s as u32); - } -} - -fn process_inputs_outputs<'a, R>( - inputs: &[&TensorDescription], - outputs: &[&TensorDescription], - context: &'a mut Context<'_, JitFusionHandle>, - stateful: bool, -) -> ( - Vec>, - Vec<&'a TensorDescription>, - Vec<&'a TensorDescription>, -) -where - R: JitRuntime, -{ - let mut inputs_description_updated = Vec::with_capacity(inputs.len()); - let mut outputs_description_updated = Vec::with_capacity(outputs.len()); - let mut handles_input = Vec::new(); - - for tensor in inputs.iter() { - let status = if stateful { - &tensor.status // Important to take the status of the relative graph and not - // the global graph, since the status of the global graph - // might be of a later operation on the same tensor id. 
- } else { - &TensorStatus::ReadOnly - }; - - let tensor = context.tensors.get(&tensor.id).unwrap(); - let handle = context.handles.get_handle(&tensor.id, status); - - handles_input.push(handle); - inputs_description_updated.push(tensor); - } - - for tensor in outputs.iter() { - let tensor = context.tensors.get(&tensor.id).unwrap(); - outputs_description_updated.push(tensor); - } - - ( - handles_input, - inputs_description_updated, - outputs_description_updated, - ) -} diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index be8208571..bd1fe3ddf 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -31,6 +31,24 @@ pub fn read( }; Line::cast_from(tensor[offset]) } + ElemwisePrecision::BF16 => { + let tensor = inputs.t_bf16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U64 => { + let tensor = inputs.t_u64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } ElemwisePrecision::U32 => { let tensor = inputs.t_u32.index(pos); let offset = match layout { @@ -40,6 +58,33 @@ pub fn read( }; Line::cast_from(tensor[offset]) } + ElemwisePrecision::U16 => { + let tensor = inputs.t_u16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U8 => { + let tensor = inputs.t_u8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I64 => { + let tensor = inputs.t_i64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } ElemwisePrecision::I32 => { let tensor = inputs.t_i32.index(pos); let offset = match layout { @@ -49,6 +94,24 @@ pub fn read( }; Line::cast_from(tensor[offset]) } + ElemwisePrecision::I16 => { + let tensor = inputs.t_i16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I8 => { + let tensor = inputs.t_i8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Output(pos, precision, layout) => match comptime![precision] { @@ -70,6 +133,24 @@ pub fn read( }; Line::cast_from(tensor[offset]) } + ElemwisePrecision::BF16 => { + let tensor = outputs.t_bf16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, 
tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U64 => { + let tensor = outputs.t_u64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } ElemwisePrecision::U32 => { let tensor = outputs.t_u32.index(pos); let offset = match layout { @@ -79,6 +160,33 @@ pub fn read( }; Line::cast_from(tensor[offset]) } + ElemwisePrecision::U16 => { + let tensor = outputs.t_u16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U8 => { + let tensor = outputs.t_u8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I64 => { + let tensor = outputs.t_i64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } ElemwisePrecision::I32 => { let tensor = outputs.t_i32.index(pos); let offset = match layout { @@ -88,22 +196,52 @@ pub fn read( }; Line::cast_from(tensor[offset]) } + ElemwisePrecision::I16 => { + let tensor = outputs.t_i16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I8 => { + let tensor = outputs.t_i8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Local(pos, precision) => match comptime![precision] { ElemwisePrecision::F32 => Line::cast_from(locals.l_f32.find(pos)), ElemwisePrecision::F16 => Line::cast_from(locals.l_f16.find(pos)), + ElemwisePrecision::BF16 => Line::cast_from(locals.l_bf16.find(pos)), + ElemwisePrecision::U64 => Line::cast_from(locals.l_u64.find(pos)), ElemwisePrecision::U32 => Line::cast_from(locals.l_u32.find(pos)), + ElemwisePrecision::U16 => Line::cast_from(locals.l_u16.find(pos)), + ElemwisePrecision::U8 => Line::cast_from(locals.l_u8.find(pos)), + ElemwisePrecision::I64 => Line::cast_from(locals.l_i64.find(pos)), ElemwisePrecision::I32 => Line::cast_from(locals.l_i32.find(pos)), + ElemwisePrecision::I16 => Line::cast_from(locals.l_i16.find(pos)), + ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)), ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)), - _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Scalar(pos, precision) => match comptime![precision] { ElemwisePrecision::F32 => Line::cast_from(*inputs.s_f32.index(pos)), ElemwisePrecision::F16 => Line::cast_from(*inputs.s_f16.index(pos)), + ElemwisePrecision::BF16 => Line::cast_from(*inputs.s_bf16.index(pos)), + ElemwisePrecision::U64 => Line::cast_from(*inputs.s_u64.index(pos)), ElemwisePrecision::U32 => 
Line::cast_from(*inputs.s_u32.index(pos)), + ElemwisePrecision::U16 => Line::cast_from(*inputs.s_u16.index(pos)), + ElemwisePrecision::U8 => Line::cast_from(*inputs.s_u8.index(pos)), + ElemwisePrecision::I64 => Line::cast_from(*inputs.s_i64.index(pos)), ElemwisePrecision::I32 => Line::cast_from(*inputs.s_i32.index(pos)), - ElemwisePrecision::BF16 => comptime![panic!("Can't write into inputs or scalars")], + ElemwisePrecision::I16 => Line::cast_from(*inputs.s_i16.index(pos)), + ElemwisePrecision::I8 => Line::cast_from(*inputs.s_i8.index(pos)), _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Literal(val, _precision) => Line::cast_from(val.runtime()), @@ -143,6 +281,26 @@ pub fn write( let tensor = outputs.t_f16.index_mut(pos); tensor[offset] = Line::cast_from(value); } + ElemwisePrecision::BF16 => { + let tensor = outputs.t_bf16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + let tensor = outputs.t_bf16.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } + ElemwisePrecision::U64 => { + let tensor = outputs.t_u64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + let tensor = outputs.t_u64.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } ElemwisePrecision::U32 => { let tensor = outputs.t_u32.index(pos); let offset = match layout { @@ -153,6 +311,36 @@ pub fn write( let tensor = outputs.t_u32.index_mut(pos); tensor[offset] = Line::cast_from(value); } + ElemwisePrecision::U16 => { + let tensor = outputs.t_u16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + let tensor = outputs.t_u16.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } + ElemwisePrecision::U8 => { + let tensor = outputs.t_u8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + let tensor = outputs.t_u8.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } + ElemwisePrecision::I64 => { + let tensor = outputs.t_i64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + let tensor = outputs.t_i64.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } ElemwisePrecision::I32 => { let tensor = outputs.t_i32.index(pos); let offset = match layout { @@ -163,15 +351,41 @@ pub fn write( let tensor = outputs.t_i32.index_mut(pos); tensor[offset] = Line::cast_from(value); } + ElemwisePrecision::I16 => { + let tensor = outputs.t_i16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + let tensor = outputs.t_i16.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } + ElemwisePrecision::I8 => { + let tensor = outputs.t_i8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; 
+ let tensor = outputs.t_i8.index_mut(pos); + tensor[offset] = Line::cast_from(value); + } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Local(pos, precision) => match comptime![precision] { ElemwisePrecision::F32 => locals.l_f32.insert(pos, Line::cast_from(value)), ElemwisePrecision::F16 => locals.l_f16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U64 => locals.l_u64.insert(pos, Line::cast_from(value)), ElemwisePrecision::U32 => locals.l_u32.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U16 => locals.l_u16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U8 => locals.l_u8.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I64 => locals.l_i64.insert(pos, Line::cast_from(value)), ElemwisePrecision::I32 => locals.l_i32.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I16 => locals.l_i16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I8 => locals.l_i8.insert(pos, Line::cast_from(value)), ElemwisePrecision::Bool => locals.l_bool.insert(pos, Line::cast_from(value)), - _ => comptime![panic!("Unsupported precision {precision:?}")], }, _ => comptime![panic!("Can't write into inputs and scalars")], } @@ -195,14 +409,42 @@ fn get_offset( let layout = inputs.t_f16.index(index); index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) } + ElemwisePrecision::BF16 => { + let layout = inputs.t_bf16.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::U64 => { + let layout = inputs.t_u64.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } ElemwisePrecision::U32 => { let layout = inputs.t_u32.index(index); index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) } + ElemwisePrecision::U16 => { + let layout = inputs.t_u16.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::U8 => { + let layout = inputs.t_u8.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::I64 => { + let layout = inputs.t_i64.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } ElemwisePrecision::I32 => { let layout = inputs.t_i32.index(index); index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) } + ElemwisePrecision::I16 => { + let layout = inputs.t_i16.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::I8 => { + let layout = inputs.t_i8.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Output(index, precision, _) => match comptime![precision] { @@ -214,14 +456,42 @@ fn get_offset( let layout = outputs.t_f16.index(index); index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) } + ElemwisePrecision::BF16 => { + let layout = outputs.t_bf16.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::U64 => { + let layout = outputs.t_u64.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } ElemwisePrecision::U32 => { let layout = outputs.t_u32.index(index); index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) } + ElemwisePrecision::U16 => { + let layout = outputs.t_u16.index(index); + index_offset_with_layout(tensor, 
layout, pos, 0, config.rank, false) + } + ElemwisePrecision::U8 => { + let layout = outputs.t_u8.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::I64 => { + let layout = outputs.t_i64.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } ElemwisePrecision::I32 => { let layout = outputs.t_i32.index(index); index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) } + ElemwisePrecision::I16 => { + let layout = outputs.t_i16.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } + ElemwisePrecision::I8 => { + let layout = outputs.t_i8.index(index); + index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + } _ => comptime![panic!("Unsupported precision {precision:?}")], }, _ => comptime![panic!("Invalid ref layout.")], diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 3198ee93e..eb99d3832 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -79,13 +79,25 @@ pub struct GlobalArgs { pub t_f32: Sequence>>, pub t_f16: Sequence>>, pub t_bf16: Sequence>>, + pub t_i64: Sequence>>, pub t_i32: Sequence>>, + pub t_i16: Sequence>>, + pub t_i8: Sequence>>, + pub t_u64: Sequence>>, pub t_u32: Sequence>>, + pub t_u16: Sequence>>, + pub t_u8: Sequence>>, pub s_f32: Sequence, pub s_f16: Sequence, pub s_bf16: Sequence, + pub s_i64: Sequence, pub s_i32: Sequence, + pub s_i16: Sequence, + pub s_i8: Sequence, + pub s_u64: Sequence, pub s_u32: Sequence, + pub s_u16: Sequence, + pub s_u8: Sequence, } #[derive(CubeType, Clone)] @@ -95,8 +107,14 @@ pub struct LocalArgs { pub l_f32: Registry>, pub l_f16: Registry>, pub l_bf16: Registry>, + pub l_i64: Registry>, pub l_i32: Registry>, + pub l_i16: Registry>, + pub l_i8: Registry>, + pub l_u64: Registry>, pub l_u32: Registry>, + pub l_u16: Registry>, + pub l_u8: Registry>, pub l_bool: Registry>, } @@ -123,9 +141,13 @@ pub enum ElemwisePrecision { F32, F16, BF16, + I64, I32, + I16, I8, + U64, U32, + U16, U8, Bool, } @@ -139,8 +161,18 @@ impl From for ElemwisePrecision { cubecl::ir::FloatKind::F32 => Self::F32, _ => panic!("Unsupported precision for fusion: {value}"), }, - Elem::Int(cubecl::ir::IntKind::I32) => Self::I32, - Elem::UInt => Self::U32, + Elem::Int(kind) => match kind { + cubecl::ir::IntKind::I64 => Self::I64, + cubecl::ir::IntKind::I32 => Self::I32, + cubecl::ir::IntKind::I16 => Self::I16, + cubecl::ir::IntKind::I8 => Self::I8, + }, + Elem::UInt(kind) => match kind { + cubecl::ir::UIntKind::U64 => Self::U64, + cubecl::ir::UIntKind::U32 => Self::U32, + cubecl::ir::UIntKind::U16 => Self::U16, + cubecl::ir::UIntKind::U8 => Self::U8, + }, Elem::Bool => Self::Bool, _ => panic!("Unsupported precision for fusion: {value}"), } @@ -153,9 +185,13 @@ impl From for ElemwisePrecision { DType::F32 => Self::F32, DType::F16 => Self::F16, DType::BF16 => Self::BF16, + DType::I64 => Self::I64, DType::I32 => Self::I32, + DType::I16 => Self::I16, DType::I8 => Self::I8, + DType::U64 => Self::U64, DType::U32 => Self::U32, + DType::U16 => Self::U16, DType::U8 => Self::U8, DType::Bool => Self::Bool, _ => panic!("Unsupported"), diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index 97c2408a5..269ba1f3b 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -19,8 +19,14 @@ pub fn fuse_on_write( l_f32: 
Registry::>::new(), l_f16: Registry::>::new(), l_bf16: Registry::>::new(), + l_i64: Registry::>::new(), l_i32: Registry::>::new(), + l_i16: Registry::>::new(), + l_i8: Registry::>::new(), + l_u64: Registry::>::new(), l_u32: Registry::>::new(), + l_u16: Registry::>::new(), + l_u8: Registry::>::new(), l_bool: Registry::>::new(), }; @@ -48,12 +54,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { add::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + add::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { add::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + add::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + add::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + add::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { add::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + add::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + add::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Div(op) => match op.out.precision() { @@ -66,12 +90,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { div::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + div::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { div::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + div::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + div::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + div::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { div::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + div::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + div::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Sub(op) => match op.out.precision() { @@ -84,12 +126,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { sub::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + sub::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { sub::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + sub::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + sub::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + sub::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { sub::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + sub::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + sub::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Mul(op) => match op.out.precision() { @@ -102,12 +162,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { mul::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + mul::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { 
mul::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + mul::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + mul::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + mul::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { mul::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + mul::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + mul::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Powf(op) => match op.out.precision() { @@ -144,12 +222,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { abs::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U64 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { assign::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I64 => { + abs::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { abs::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + abs::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + abs::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Log(op) => match op.out.precision() { @@ -198,16 +294,33 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { assign::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { assign::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { assign::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + assign::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::Bool => { assign::(inputs, outputs, &mut locals, write_pos, op, config) } - _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Exp(op) => match op.out.precision() { ElemwisePrecision::F32 => { @@ -267,12 +380,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { equal::(inputs, outputs, &mut locals, 
write_pos, op, config) } + ElemwisePrecision::U16 => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Greater(op) => match op.lhs.precision() { @@ -285,12 +416,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { greater::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + greater::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { greater::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + greater::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + greater::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + greater::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { greater::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + greater::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + greater::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::GreaterEqual(op) => match op.lhs.precision() { @@ -303,12 +452,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::Lower(op) => match op.lhs.precision() { @@ -321,12 +488,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { lower::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + lower::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { lower::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + lower::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + lower::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + lower::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { lower::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + lower::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + lower::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::LowerEqual(op) => match op.lhs.precision() { @@ -339,12 +524,30 @@ pub fn fuse_on_write( ElemwisePrecision::BF16 => { lower_equal::(inputs, 
outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I64 => { + lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::I32 => { lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::I16 => { + lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::I8 => { + lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U64 => { + lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } ElemwisePrecision::U32 => { lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) } + ElemwisePrecision::U16 => { + lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } + ElemwisePrecision::U8 => { + lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) + } _ => comptime![panic!("Unsupported precision {op:?}")], }, ElemwiseOp::ConditionalAssign { @@ -386,6 +589,17 @@ pub fn fuse_on_write( out, config, ), + ElemwisePrecision::I64 => conditional_assign::( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), ElemwisePrecision::I32 => conditional_assign::( inputs, outputs, @@ -397,6 +611,39 @@ pub fn fuse_on_write( out, config, ), + ElemwisePrecision::I16 => conditional_assign::( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), + ElemwisePrecision::I8 => conditional_assign::( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), + ElemwisePrecision::U64 => conditional_assign::( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), ElemwisePrecision::U32 => conditional_assign::( inputs, outputs, @@ -408,6 +655,28 @@ pub fn fuse_on_write( out, config, ), + ElemwisePrecision::U16 => conditional_assign::( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), + ElemwisePrecision::U8 => conditional_assign::( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), _ => comptime![panic!("Unsupported precision {op:?}")], }, } diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 87b9750e2..c7ee6bde1 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -333,6 +333,18 @@ impl FuseOnWriteTrace { SequenceArg::new(), SequenceArg::new(), SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), ); for hi in handle_inputs.iter() { @@ -341,8 +353,14 @@ impl FuseOnWriteTrace { ElemwisePrecision::F32 => inputs.t_f32.push(arg), ElemwisePrecision::F16 => inputs.t_f16.push(arg), ElemwisePrecision::BF16 => inputs.t_bf16.push(arg), + ElemwisePrecision::I64 => inputs.t_i64.push(arg), ElemwisePrecision::I32 => inputs.t_i32.push(arg), + ElemwisePrecision::I16 => inputs.t_i16.push(arg), + ElemwisePrecision::I8 => inputs.t_i8.push(arg), + ElemwisePrecision::U64 => inputs.t_u64.push(arg), ElemwisePrecision::U32 => inputs.t_u32.push(arg), + ElemwisePrecision::U16 => inputs.t_u16.push(arg), + ElemwisePrecision::U8 => inputs.t_u8.push(arg), _ => panic!("Unsupported input precision {:?}", hi.precision), }; } @@ -356,10 +374,30 @@ impl FuseOnWriteTrace { 
ElemwisePrecision::F16 => { inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i])) } - ElemwisePrecision::I32 => { - inputs.s_i32.push(ScalarArg::new(context.scalar_ints[i])) + ElemwisePrecision::BF16 => { + inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i])) } - _ => todo!(), + ElemwisePrecision::I64 => { + inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i])) + } + ElemwisePrecision::I32 => { + inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i])) + } + ElemwisePrecision::I16 => { + inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i])) + } + ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])), + ElemwisePrecision::U64 => { + inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i])) + } + ElemwisePrecision::U32 => { + inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i])) + } + ElemwisePrecision::U16 => { + inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i])) + } + ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])), + ElemwisePrecision::Bool => todo!(), } } } @@ -383,6 +421,18 @@ impl FuseOnWriteTrace { SequenceArg::new(), SequenceArg::new(), SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), + SequenceArg::new(), ); for item in handle_outputs.iter() { match item { @@ -392,8 +442,15 @@ impl FuseOnWriteTrace { } => match precision { ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)), ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)), ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)), ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)), _ => todo!(), }, HandleOutput::Owned { @@ -406,11 +463,17 @@ impl FuseOnWriteTrace { match precision { ElemwisePrecision::F32 => outputs.t_f32.push(arg), ElemwisePrecision::F16 => outputs.t_f16.push(arg), + ElemwisePrecision::BF16 => outputs.t_bf16.push(arg), + ElemwisePrecision::I64 => outputs.t_i64.push(arg), ElemwisePrecision::I32 => outputs.t_i32.push(arg), + ElemwisePrecision::I16 => outputs.t_i16.push(arg), + ElemwisePrecision::I8 => outputs.t_i8.push(arg), + ElemwisePrecision::U64 => outputs.t_u64.push(arg), ElemwisePrecision::U32 => outputs.t_u32.push(arg), + ElemwisePrecision::U16 => outputs.t_u16.push(arg), + ElemwisePrecision::U8 => outputs.t_u8.push(arg), // Bools are encoded as u32. 
ElemwisePrecision::Bool => outputs.t_u32.push(arg), - _ => todo!(), }; } } diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index 722b13812..4dc4c08e1 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -127,7 +127,7 @@ pub(crate) fn launch_binop>( let vectorization_factor_rhs = tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1); - let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs); + let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs); let mut shape_out = vec![0; ndims]; lhs.shape diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index 401ab4ff1..34d7b136d 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -119,7 +119,7 @@ pub(crate) fn launch_cmp>( let vectorization_factor_rhs = tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1); - let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs); + let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs); let mut shape_out = vec![0; ndims]; lhs.shape @@ -169,7 +169,7 @@ pub(crate) fn launch_cmp>( JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides) } else { - let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); + let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; @@ -225,7 +225,7 @@ pub(crate) fn launch_scalar_cmp tensor.strides, ) } else { - let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); + let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); let output = JitTensor::new( tensor.client.clone(), buffer, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 5c6dc835a..456006cbf 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -60,7 +60,7 @@ pub fn conv2d_implicit_gemm( options.groups, out_h, out_w, - &input.device, + &input.client, ) { panic!( "Requirements for implicit GEMM not met: @@ -84,7 +84,7 @@ pub fn conv2d_implicit_gemm( let slice_size = pad_kh * pad_kw * pad_in_channels; let (cmma_m, cmma_n, cmma_k) = - find_cmma_size::(&input.device, gemm_m, gemm_k, gemm_n).unwrap(); + find_cmma_size::(&input.client, gemm_m, gemm_k, gemm_n).unwrap(); let cube_dim_x = 128; let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2); @@ -187,7 +187,7 @@ pub fn conv2d_implicit_gemm( fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 { let channels = channels as u8; let elems_per_thread = elems_per_thread as u8; - let smaller = u8::min(channels, elems_per_thread); + let smaller = Ord::min(channels, elems_per_thread); (1..=smaller) .rev() .filter(|it| supported_vecs.contains(it)) @@ -663,7 +663,7 @@ pub(crate) fn can_do_implicit_gemm( groups: usize, out_h: usize, out_w: usize, - device: &R::Device, + client: &ComputeClient, ) -> bool { let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_size[0], kernel_size[1]); let batch_size = padded_batch_size(batch_size, out_h, out_w); @@ -673,7 +673,7 @@ pub(crate) fn 
can_do_implicit_gemm( let gemm_n = out_channels; let gemm_k = in_channels * kernel_h * kernel_w; - let size = find_cmma_size::(device, gemm_m as u32, gemm_k as u32, gemm_n as u32); + let size = find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32); if let Some((cmma_m, cmma_k, cmma_n)) = size { let warps_per_cube = 8; @@ -716,12 +716,12 @@ fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize { } fn find_cmma_size( - device: &R::JitDevice, + client: &ComputeClient, gemm_m: u32, gemm_k: u32, gemm_n: u32, ) -> Option<(u32, u32, u32)> { - supported_cmma_sizes::(device) + supported_cmma_sizes::(client) .into_iter() .find(|(m, k, n)| { gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0 @@ -730,7 +730,7 @@ fn find_cmma_size( } fn supported_cmma_sizes( - device: &R::JitDevice, + client: &ComputeClient, ) -> Vec<(u8, u8, u8)> { let requested_sizes = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]; @@ -738,16 +738,14 @@ fn supported_cmma_sizes( .iter() .copied() .filter(|(m, k, n)| { - R::client(device) - .properties() - .feature_enabled(Feature::Cmma { - a: F::as_elem(), - b: F::as_elem(), - c: FAcc::as_elem(), - m: *m, - k: *k, - n: *n, - }) + client.properties().feature_enabled(Feature::Cmma { + a: F::as_elem(), + b: F::as_elem(), + c: FAcc::as_elem(), + m: *m, + k: *k, + n: *n, + }) }) .collect() } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 4a03a5a83..07b90bdde 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -3,6 +3,7 @@ use cubecl::{ ir::{ Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility, }, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -43,14 +44,14 @@ impl Conv2dTransposeComputeShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let input_stride_0 = scope.create_local(Elem::UInt); - let input_stride_1 = scope.create_local(Elem::UInt); - let input_stride_2 = scope.create_local(Elem::UInt); - let input_stride_3 = scope.create_local(Elem::UInt); - let input_shape_0 = scope.create_local(Elem::UInt); - let input_shape_1 = scope.create_local(Elem::UInt); - let input_shape_2 = scope.create_local(Elem::UInt); - let input_shape_3 = scope.create_local(Elem::UInt); + let input_stride_0 = scope.create_local(u32::as_elem()); + let input_stride_1 = scope.create_local(u32::as_elem()); + let input_stride_2 = scope.create_local(u32::as_elem()); + let input_stride_3 = scope.create_local(u32::as_elem()); + let input_shape_0 = scope.create_local(u32::as_elem()); + let input_shape_1 = scope.create_local(u32::as_elem()); + let input_shape_2 = scope.create_local(u32::as_elem()); + let input_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, input_stride_0 = stride(input, 0u32)); cpa!(scope, input_stride_1 = stride(input, 1u32)); cpa!(scope, input_stride_2 = stride(input, 2u32)); @@ -60,14 +61,14 @@ impl Conv2dTransposeComputeShader { cpa!(scope, input_shape_2 = shape(input, 2u32)); cpa!(scope, input_shape_3 = shape(input, 3u32)); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); - let output_shape_0 = scope.create_local(Elem::UInt); - let 
output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, output_stride_0 = stride(output, 0u32)); cpa!(scope, output_stride_1 = stride(output, 1u32)); cpa!(scope, output_stride_2 = stride(output, 2u32)); @@ -77,14 +78,14 @@ impl Conv2dTransposeComputeShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let weight_stride_0 = scope.create_local(Elem::UInt); - let weight_stride_1 = scope.create_local(Elem::UInt); - let weight_stride_2 = scope.create_local(Elem::UInt); - let weight_stride_3 = scope.create_local(Elem::UInt); - let in_channels = scope.create_local(Elem::UInt); - let weight_shape_1 = scope.create_local(Elem::UInt); - let kernel_size_0 = scope.create_local(Elem::UInt); - let kernel_size_1 = scope.create_local(Elem::UInt); + let weight_stride_0 = scope.create_local(u32::as_elem()); + let weight_stride_1 = scope.create_local(u32::as_elem()); + let weight_stride_2 = scope.create_local(u32::as_elem()); + let weight_stride_3 = scope.create_local(u32::as_elem()); + let in_channels = scope.create_local(u32::as_elem()); + let weight_shape_1 = scope.create_local(u32::as_elem()); + let kernel_size_0 = scope.create_local(u32::as_elem()); + let kernel_size_1 = scope.create_local(u32::as_elem()); cpa!(scope, weight_stride_0 = stride(weight, 0u32)); cpa!(scope, weight_stride_1 = stride(weight, 1u32)); cpa!(scope, weight_stride_2 = stride(weight, 2u32)); @@ -94,31 +95,31 @@ impl Conv2dTransposeComputeShader { cpa!(scope, kernel_size_0 = shape(weight, 2u32)); cpa!(scope, kernel_size_1 = shape(weight, 3u32)); - let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt)); - let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt)); - let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt)); - let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt)); - let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt)); - let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt)); - let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt)); + let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem())); + let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem())); + let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem())); + let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem())); + let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem())); + let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem())); + let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(u32::as_elem())); let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = 
scope.create_local(Elem::Int(IntKind::I32)); cpa!(scope, stride_0_i = cast(conv_stride_0)); cpa!(scope, stride_1_i = cast(conv_stride_1)); - let oc_out = scope.create_local(Elem::UInt); - let oc = scope.create_local(Elem::UInt); + let oc_out = scope.create_local(u32::as_elem()); + let oc = scope.create_local(u32::as_elem()); - let b = scope.create_local(Elem::UInt); - let oh = scope.create_local(Elem::UInt); - let ow = scope.create_local(Elem::UInt); - let k = scope.create_local(Elem::UInt); - let g = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let oh = scope.create_local(u32::as_elem()); + let ow = scope.create_local(u32::as_elem()); + let k = scope.create_local(u32::as_elem()); + let g = scope.create_local(u32::as_elem()); - let ic_start = scope.create_local(Elem::UInt); - let ic_end = scope.create_local(Elem::UInt); - let ic_tmp = scope.create_local(Elem::UInt); + let ic_start = scope.create_local(u32::as_elem()); + let ic_end = scope.create_local(u32::as_elem()); + let ic_tmp = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -141,20 +142,20 @@ impl Conv2dTransposeComputeShader { cpa!(scope, ic_start = g * ic_tmp); cpa!(scope, ic_end = ic_start + ic_tmp); - let tmp_u = scope.create_local(Elem::UInt); + let tmp_u = scope.create_local(u32::as_elem()); let tmp_i = scope.create_local(Elem::Int(IntKind::I32)); let zero_i = scope.zero(Elem::Int(IntKind::I32)); let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32)); - let kms_u = scope.create_local(Elem::UInt); + let kms_u = scope.create_local(u32::as_elem()); let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); - let ih_start = scope.create_local(Elem::UInt); - let iw_start = scope.create_local(Elem::UInt); - let ih_end = scope.create_local(Elem::UInt); - let iw_end = scope.create_local(Elem::UInt); + let ih_start = scope.create_local(u32::as_elem()); + let iw_start = scope.create_local(u32::as_elem()); + let ih_end = scope.create_local(u32::as_elem()); + let iw_end = scope.create_local(u32::as_elem()); cpa!(scope, kms_u = kernel_size_0 * dilation_0); cpa!(scope, kms_0 = cast(kms_u)); @@ -188,17 +189,17 @@ impl Conv2dTransposeComputeShader { cpa!(scope, tmp_u = cast(tmp_i)); cpa!(scope, iw_end = min(tmp_u, input_shape_3)); - let index_input = scope.create_local(Elem::UInt); - let index_weight = scope.create_local(Elem::UInt); + let index_input = scope.create_local(u32::as_elem()); + let index_weight = scope.create_local(u32::as_elem()); - let index_input_b = scope.create_local(Elem::UInt); - let index_input_ic = scope.create_local(Elem::UInt); - let index_input_ih = scope.create_local(Elem::UInt); - let index_input_iw = scope.create_local(Elem::UInt); - let index_weight_ic = scope.create_local(Elem::UInt); - let index_weight_oc = scope.create_local(Elem::UInt); - let index_weight_kh = scope.create_local(Elem::UInt); - let index_weight_kw = scope.create_local(Elem::UInt); + let index_input_b = scope.create_local(u32::as_elem()); + let index_input_ic = scope.create_local(u32::as_elem()); + let index_input_ih = scope.create_local(u32::as_elem()); + let index_input_iw = scope.create_local(u32::as_elem()); + let index_weight_ic = scope.create_local(u32::as_elem()); + let index_weight_oc = scope.create_local(u32::as_elem()); + let index_weight_kh = 
scope.create_local(u32::as_elem()); + let index_weight_kw = scope.create_local(u32::as_elem()); cpa!(scope, index_input_b = b * input_stride_0); cpa!(scope, index_weight_oc = oc * weight_stride_1); @@ -208,15 +209,15 @@ impl Conv2dTransposeComputeShader { let sum = scope.create_local(output.item); cpa!(scope, sum = bias[oc_out]); - let kh = scope.create_local(Elem::UInt); - let kw = scope.create_local(Elem::UInt); - let numerator_h_base = scope.create_local(Elem::UInt); - let numerator_h = scope.create_local(Elem::UInt); - let numerator_w_base = scope.create_local(Elem::UInt); - let numerator_w = scope.create_local(Elem::UInt); - let numerator_tmp = scope.create_local(Elem::UInt); - let numerator_mod = scope.create_local(Elem::UInt); - let zero = scope.zero(Elem::UInt); + let kh = scope.create_local(u32::as_elem()); + let kw = scope.create_local(u32::as_elem()); + let numerator_h_base = scope.create_local(u32::as_elem()); + let numerator_h = scope.create_local(u32::as_elem()); + let numerator_w_base = scope.create_local(u32::as_elem()); + let numerator_w = scope.create_local(u32::as_elem()); + let numerator_tmp = scope.create_local(u32::as_elem()); + let numerator_mod = scope.create_local(u32::as_elem()); + let zero = scope.zero(u32::as_elem()); let divisible = scope.create_local(Elem::Bool); let not_neg = scope.create_local(Elem::Bool); let cond = scope.create_local(Elem::Bool); @@ -324,7 +325,7 @@ impl Kernel for Conv2dTransposeEagerKernel { visibility: Visibility::Read, }; let scalars = InputInfo::Scalar { - elem: Elem::UInt, + elem: u32::as_elem(), size: 7, }; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index f86390165..105b0f1f2 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -9,10 +9,7 @@ use cubecl::{ use crate::{ kernel::{ - conv::{ - batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col, - conv2d_implicit_gemm, - }, + conv::{batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col}, prng::random_uniform, }, tensor::JitTensor, @@ -42,7 +39,7 @@ pub fn conv2d_autotune( } #[tune( - operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm), + operations(conv2d_direct, conv2d_im2col), create_key = create_key, should_run = should_run )] @@ -111,7 +108,7 @@ fn should_run( op.options.groups, out_h, out_w, - &op.input.device, + &op.input.client, ), _ => true, } @@ -143,5 +140,6 @@ fn create_key( width, batch_size, bias.is_some(), + E::dtype(), )) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index 8b15808f2..aa0f0972a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -91,6 +91,7 @@ fn create_key( width, batch_size, bias.is_some(), + E::dtype(), )) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/key.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/key.rs index 224eef450..829ef9067 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/key.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/key.rs @@ -1,3 +1,4 @@ +use burn_tensor::DType; use cubecl::AutotuneKey; use serde::{Deserialize, Serialize}; @@ -20,6 +21,7 @@ pub struct Conv2dAutotuneKey { #[autotune(anchor)] pub batch_size: usize, pub has_bias: bool, + pub dtype: DType, } #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, 
Deserialize, AutotuneKey)] @@ -42,4 +44,5 @@ pub struct ConvTranspose2dAutotuneKey { #[autotune(anchor)] pub batch_size: usize, pub has_bias: bool, + pub dtype: DType, } diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs index 0e7a53812..283346dd6 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs @@ -3,6 +3,7 @@ use cubecl::{ ir::{ Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility, }, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -43,16 +44,16 @@ impl Conv3dTransposeComputeShader { let output = self.output; let idx = Variable::builtin(Builtin::AbsolutePos); - let input_stride_0 = scope.create_local(Elem::UInt); - let input_stride_1 = scope.create_local(Elem::UInt); - let input_stride_2 = scope.create_local(Elem::UInt); - let input_stride_3 = scope.create_local(Elem::UInt); - let input_stride_4 = scope.create_local(Elem::UInt); - let input_shape_0 = scope.create_local(Elem::UInt); - let input_shape_1 = scope.create_local(Elem::UInt); - let input_shape_2 = scope.create_local(Elem::UInt); - let input_shape_3 = scope.create_local(Elem::UInt); - let input_shape_4 = scope.create_local(Elem::UInt); + let input_stride_0 = scope.create_local(u32::as_elem()); + let input_stride_1 = scope.create_local(u32::as_elem()); + let input_stride_2 = scope.create_local(u32::as_elem()); + let input_stride_3 = scope.create_local(u32::as_elem()); + let input_stride_4 = scope.create_local(u32::as_elem()); + let input_shape_0 = scope.create_local(u32::as_elem()); + let input_shape_1 = scope.create_local(u32::as_elem()); + let input_shape_2 = scope.create_local(u32::as_elem()); + let input_shape_3 = scope.create_local(u32::as_elem()); + let input_shape_4 = scope.create_local(u32::as_elem()); cpa!(scope, input_stride_0 = stride(input, 0u32)); cpa!(scope, input_stride_1 = stride(input, 1u32)); cpa!(scope, input_stride_2 = stride(input, 2u32)); @@ -64,16 +65,16 @@ impl Conv3dTransposeComputeShader { cpa!(scope, input_shape_3 = shape(input, 3u32)); cpa!(scope, input_shape_4 = shape(input, 4u32)); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); - let output_stride_4 = scope.create_local(Elem::UInt); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); - let output_shape_4 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); + let output_stride_4 = scope.create_local(u32::as_elem()); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); + let output_shape_4 = scope.create_local(u32::as_elem()); cpa!(scope, output_stride_0 = stride(output, 0u32)); cpa!(scope, output_stride_1 = stride(output, 1u32)); cpa!(scope, output_stride_2 = 
stride(output, 2u32)); @@ -85,16 +86,16 @@ impl Conv3dTransposeComputeShader { cpa!(scope, output_shape_3 = shape(output, 3u32)); cpa!(scope, output_shape_4 = shape(output, 4u32)); - let weight_stride_0 = scope.create_local(Elem::UInt); - let weight_stride_1 = scope.create_local(Elem::UInt); - let weight_stride_2 = scope.create_local(Elem::UInt); - let weight_stride_3 = scope.create_local(Elem::UInt); - let weight_stride_4 = scope.create_local(Elem::UInt); - let in_channels = scope.create_local(Elem::UInt); - let weight_shape_1 = scope.create_local(Elem::UInt); - let kernel_size_0 = scope.create_local(Elem::UInt); - let kernel_size_1 = scope.create_local(Elem::UInt); - let kernel_size_2 = scope.create_local(Elem::UInt); + let weight_stride_0 = scope.create_local(u32::as_elem()); + let weight_stride_1 = scope.create_local(u32::as_elem()); + let weight_stride_2 = scope.create_local(u32::as_elem()); + let weight_stride_3 = scope.create_local(u32::as_elem()); + let weight_stride_4 = scope.create_local(u32::as_elem()); + let in_channels = scope.create_local(u32::as_elem()); + let weight_shape_1 = scope.create_local(u32::as_elem()); + let kernel_size_0 = scope.create_local(u32::as_elem()); + let kernel_size_1 = scope.create_local(u32::as_elem()); + let kernel_size_2 = scope.create_local(u32::as_elem()); cpa!(scope, weight_stride_0 = stride(weight, 0u32)); cpa!(scope, weight_stride_1 = stride(weight, 1u32)); cpa!(scope, weight_stride_2 = stride(weight, 2u32)); @@ -106,16 +107,16 @@ impl Conv3dTransposeComputeShader { cpa!(scope, kernel_size_1 = shape(weight, 3u32)); cpa!(scope, kernel_size_2 = shape(weight, 4u32)); - let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt)); - let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt)); - let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt)); - let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt)); - let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt)); - let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt)); - let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt)); - let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(Elem::UInt)); - let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(Elem::UInt)); - let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(Elem::UInt)); + let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem())); + let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem())); + let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem())); + let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem())); + let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem())); + let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem())); + let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(u32::as_elem())); + let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(u32::as_elem())); + let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(u32::as_elem())); + let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(u32::as_elem())); let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = 
scope.create_local(Elem::Int(IntKind::I32)); @@ -124,19 +125,19 @@ impl Conv3dTransposeComputeShader { cpa!(scope, stride_1_i = cast(conv_stride_1)); cpa!(scope, stride_2_i = cast(conv_stride_2)); - let oc_out = scope.create_local(Elem::UInt); - let oc = scope.create_local(Elem::UInt); + let oc_out = scope.create_local(u32::as_elem()); + let oc = scope.create_local(u32::as_elem()); - let b = scope.create_local(Elem::UInt); - let od = scope.create_local(Elem::UInt); - let oh = scope.create_local(Elem::UInt); - let ow = scope.create_local(Elem::UInt); - let k = scope.create_local(Elem::UInt); - let g = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let od = scope.create_local(u32::as_elem()); + let oh = scope.create_local(u32::as_elem()); + let ow = scope.create_local(u32::as_elem()); + let k = scope.create_local(u32::as_elem()); + let g = scope.create_local(u32::as_elem()); - let ic_start = scope.create_local(Elem::UInt); - let ic_end = scope.create_local(Elem::UInt); - let ic_tmp = scope.create_local(Elem::UInt); + let ic_start = scope.create_local(u32::as_elem()); + let ic_end = scope.create_local(u32::as_elem()); + let ic_tmp = scope.create_local(u32::as_elem()); cpa!(scope, b = idx / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -162,24 +163,24 @@ impl Conv3dTransposeComputeShader { cpa!(scope, ic_start = g * ic_tmp); cpa!(scope, ic_end = ic_start + ic_tmp); - let tmp_u = scope.create_local(Elem::UInt); + let tmp_u = scope.create_local(u32::as_elem()); let tmp_i = scope.create_local(Elem::Int(IntKind::I32)); let zero_i = scope.zero(Elem::Int(IntKind::I32)); let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32)); - let kms_u = scope.create_local(Elem::UInt); + let kms_u = scope.create_local(u32::as_elem()); let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); let kms_2 = scope.create_local(Elem::Int(IntKind::I32)); let id_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); - let id_start = scope.create_local(Elem::UInt); - let ih_start = scope.create_local(Elem::UInt); - let iw_start = scope.create_local(Elem::UInt); - let id_end = scope.create_local(Elem::UInt); - let ih_end = scope.create_local(Elem::UInt); - let iw_end = scope.create_local(Elem::UInt); + let id_start = scope.create_local(u32::as_elem()); + let ih_start = scope.create_local(u32::as_elem()); + let iw_start = scope.create_local(u32::as_elem()); + let id_end = scope.create_local(u32::as_elem()); + let ih_end = scope.create_local(u32::as_elem()); + let iw_end = scope.create_local(u32::as_elem()); cpa!(scope, kms_u = kernel_size_0 * dilation_0); cpa!(scope, kms_0 = cast(kms_u)); @@ -228,19 +229,19 @@ impl Conv3dTransposeComputeShader { cpa!(scope, tmp_u = cast(tmp_i)); cpa!(scope, iw_end = min(tmp_u, input_shape_4)); - let index_input = scope.create_local(Elem::UInt); - let index_weight = scope.create_local(Elem::UInt); + let index_input = scope.create_local(u32::as_elem()); + let index_weight = scope.create_local(u32::as_elem()); - let index_input_b = scope.create_local(Elem::UInt); - let index_input_ic = scope.create_local(Elem::UInt); - let index_input_id = scope.create_local(Elem::UInt); - let index_input_ih = scope.create_local(Elem::UInt); - let index_input_iw = scope.create_local(Elem::UInt); - let index_weight_ic = scope.create_local(Elem::UInt); - let index_weight_oc = 
scope.create_local(Elem::UInt); - let index_weight_kd = scope.create_local(Elem::UInt); - let index_weight_kh = scope.create_local(Elem::UInt); - let index_weight_kw = scope.create_local(Elem::UInt); + let index_input_b = scope.create_local(u32::as_elem()); + let index_input_ic = scope.create_local(u32::as_elem()); + let index_input_id = scope.create_local(u32::as_elem()); + let index_input_ih = scope.create_local(u32::as_elem()); + let index_input_iw = scope.create_local(u32::as_elem()); + let index_weight_ic = scope.create_local(u32::as_elem()); + let index_weight_oc = scope.create_local(u32::as_elem()); + let index_weight_kd = scope.create_local(u32::as_elem()); + let index_weight_kh = scope.create_local(u32::as_elem()); + let index_weight_kw = scope.create_local(u32::as_elem()); cpa!(scope, index_input_b = b * input_stride_0); cpa!(scope, index_weight_oc = oc * weight_stride_1); @@ -250,18 +251,18 @@ impl Conv3dTransposeComputeShader { let sum = scope.create_local(output.item); cpa!(scope, sum = bias[oc_out]); - let kd = scope.create_local(Elem::UInt); - let kh = scope.create_local(Elem::UInt); - let kw = scope.create_local(Elem::UInt); - let numerator_d_base = scope.create_local(Elem::UInt); - let numerator_d = scope.create_local(Elem::UInt); - let numerator_h_base = scope.create_local(Elem::UInt); - let numerator_h = scope.create_local(Elem::UInt); - let numerator_w_base = scope.create_local(Elem::UInt); - let numerator_w = scope.create_local(Elem::UInt); - let numerator_tmp = scope.create_local(Elem::UInt); - let numerator_mod = scope.create_local(Elem::UInt); - let zero = scope.zero(Elem::UInt); + let kd = scope.create_local(u32::as_elem()); + let kh = scope.create_local(u32::as_elem()); + let kw = scope.create_local(u32::as_elem()); + let numerator_d_base = scope.create_local(u32::as_elem()); + let numerator_d = scope.create_local(u32::as_elem()); + let numerator_h_base = scope.create_local(u32::as_elem()); + let numerator_h = scope.create_local(u32::as_elem()); + let numerator_w_base = scope.create_local(u32::as_elem()); + let numerator_w = scope.create_local(u32::as_elem()); + let numerator_tmp = scope.create_local(u32::as_elem()); + let numerator_mod = scope.create_local(u32::as_elem()); + let zero = scope.zero(u32::as_elem()); let divisible = scope.create_local(Elem::Bool); let not_neg = scope.create_local(Elem::Bool); let cond = scope.create_local(Elem::Bool); @@ -392,7 +393,7 @@ impl Kernel for Conv3dTransposeEagerKernel { visibility: Visibility::Read, }; let scalars = InputInfo::Scalar { - elem: Elem::UInt, + elem: u32::as_elem(), size: 10, }; diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 6673b9648..013181a3e 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -5,7 +5,7 @@ use burn_tensor::{ use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch}; use crate::{ - kernel::into_contiguous, + kernel::{cast, into_contiguous}, ops::{ numeric::{empty_device, ones_device, zeros_device}, reshape, swap_dims, @@ -426,7 +426,8 @@ fn compute_input_grad( let [batch_size, in_channels, height, width] = input_shape.dims(); let (kernel_height, kernel_width) = kernel_dims; - let grad_in = zeros_device::( + // Force `f32` to enable bitcasting as `u32` + let grad_in = zeros_device::( client.clone(), device.clone(), Shape::new([batch_size, in_channels, height, width]), @@ -466,7 +467,7 @@ fn 
compute_input_grad( use_mask, ); - grad_in + cast(grad_in) } #[derive(CubeLaunch)] @@ -564,19 +565,19 @@ fn deform_col2img_kernel( let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS]; - float_atomic_add::(&mut grad_input[gradient_pos], value); + float_atomic_add(&mut grad_input[gradient_pos], f32::cast_from(value)); } } } } #[cube] -fn float_atomic_add(ptr: &mut AtomicU32, value: F) { - if value != F::new(0.0) { +fn float_atomic_add(ptr: &mut AtomicU32, value: f32) { + if value != 0.0 { let mut v = AtomicU32::load(ptr); loop { let prev = v; - let v_float = F::bitcast_from(v); + let v_float = f32::bitcast_from(v); let new = u32::bitcast_from(v_float + value); v = AtomicU32::compare_and_swap(ptr, v, new); if prev == v { diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 6e6037136..44d4281f4 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -5,6 +5,7 @@ use burn_tensor::ElementConversion; use cubecl::{ cpa, ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -29,12 +30,12 @@ impl FlipComputeShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let offset_input = scope.zero(Elem::UInt); - let offset_local = scope.create_local(Elem::UInt); + let offset_input = scope.zero(u32::as_elem()); + let offset_local = scope.create_local(u32::as_elem()); - let stride = scope.create_local(Elem::UInt); - let shape = scope.create_local(Elem::UInt); - let flip = scope.create_local(Elem::UInt); + let stride = scope.create_local(u32::as_elem()); + let shape = scope.create_local(u32::as_elem()); + let flip = scope.create_local(u32::as_elem()); let flip_bool = scope.create_local(Elem::Bool); for i in 0..self.rank { @@ -44,7 +45,7 @@ impl FlipComputeShader { scope, flip = cast(Variable::new( VariableKind::GlobalScalar(i as u16), - Item::new(Elem::UInt) + Item::new(u32::as_elem()) )) ); cpa!(scope, flip_bool = flip == 1u32); @@ -89,7 +90,7 @@ impl Kernel for FlipEagerKernel { visibility: Visibility::Read, }; let flip_dims = InputInfo::Scalar { - elem: Elem::UInt, + elem: u32::as_elem(), size: self.rank, }; let output = OutputInfo::Array { item }; diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index d071f8dd3..33f92b377 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -1,7 +1,8 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; use cubecl::{ cpa, - ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::CubePrimitive, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -28,12 +29,12 @@ impl RepeatComputeShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let offset_input = scope.zero(Elem::UInt); - let offset_local = scope.zero(Elem::UInt); + let offset_input = scope.zero(u32::as_elem()); + let offset_local = scope.zero(u32::as_elem()); - let stride_input = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape = scope.create_local(Elem::UInt); + let stride_input = scope.create_local(u32::as_elem()); 
+ let stride_output = scope.create_local(u32::as_elem()); + let shape = scope.create_local(u32::as_elem()); for i in 0..self.rank { cpa!(scope, stride_input = stride(input, i)); diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index ff146ec2a..a6dd72c02 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -2,15 +2,15 @@ use crate::{ element::JitElement, kernel::{self}, tensor::JitTensor, - JitRuntime, + IntElement, JitRuntime, }; use cubecl::prelude::*; use cubecl::{calculate_cube_count_elemwise, CubeDim}; #[cube(launch_unchecked)] -fn scatter_kernel( +fn scatter_kernel( input: &mut Tensor, - indices: &Tensor, + indices: &Tensor, value: &Tensor, dim: &u32, ) { @@ -65,7 +65,7 @@ fn scatter_kernel( } } -pub(crate) fn scatter( +pub(crate) fn scatter( dim: usize, tensor: JitTensor, indices: JitTensor, @@ -105,7 +105,7 @@ pub(crate) fn scatter( let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); unsafe { - scatter_kernel::launch_unchecked::( + scatter_kernel::launch_unchecked::( &indices.client.clone(), cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index b2093a449..94bfddb74 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -4,7 +4,8 @@ use crate::{ use burn_tensor::{ElementConversion, Shape}; use cubecl::{ cpa, - ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + ir::{Builtin, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -29,13 +30,13 @@ impl SliceComputeShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let offset_input = scope.zero(Elem::UInt); - let offset_local = scope.create_local(Elem::UInt); + let offset_input = scope.zero(u32::as_elem()); + let offset_local = scope.create_local(u32::as_elem()); - let stride_input = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape_output = scope.create_local(Elem::UInt); - let range_start = scope.create_local(Elem::UInt); + let stride_input = scope.create_local(u32::as_elem()); + let stride_output = scope.create_local(u32::as_elem()); + let shape_output = scope.create_local(u32::as_elem()); + let range_start = scope.create_local(u32::as_elem()); for i in 0..self.rank { cpa!(scope, stride_input = stride(input, i)); @@ -45,7 +46,7 @@ impl SliceComputeShader { scope, range_start = cast(Variable::new( VariableKind::GlobalScalar(i as u16), - Item::new(Elem::UInt) + Item::new(u32::as_elem()) )) ); @@ -85,7 +86,7 @@ impl Kernel for SliceEagerKernel { visibility: Visibility::Read, }; let ranges = InputInfo::Scalar { - elem: Elem::UInt, + elem: u32::as_elem(), size: self.rank, }; let output = OutputInfo::Array { item }; diff --git a/crates/burn-jit/src/kernel/index/slice_assign.rs b/crates/burn-jit/src/kernel/index/slice_assign.rs index fbfb27070..9694b4419 100644 --- a/crates/burn-jit/src/kernel/index/slice_assign.rs +++ b/crates/burn-jit/src/kernel/index/slice_assign.rs @@ -2,7 +2,8 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; use burn_tensor::ElementConversion; use cubecl::{ cpa, - ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + ir::{Builtin, Item, 
KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, }; use std::{marker::PhantomData, ops::Range}; @@ -26,18 +27,18 @@ impl SliceAssignComputeShader { let value = self.value; let id = Variable::builtin(Builtin::AbsolutePos); - let offset_input = scope.zero(Elem::UInt); - let offset_value = scope.zero(Elem::UInt); + let offset_input = scope.zero(u32::as_elem()); + let offset_value = scope.zero(u32::as_elem()); - let offset_local = scope.create_local(Elem::UInt); - let offset_local_value = scope.create_local(Elem::UInt); - let offset_local_input = scope.create_local(Elem::UInt); + let offset_local = scope.create_local(u32::as_elem()); + let offset_local_value = scope.create_local(u32::as_elem()); + let offset_local_input = scope.create_local(u32::as_elem()); - let stride_input = scope.create_local(Elem::UInt); - let stride_value = scope.create_local(Elem::UInt); - let shape_value = scope.create_local(Elem::UInt); - let shape_input = scope.create_local(Elem::UInt); - let range_start = scope.create_local(Elem::UInt); + let stride_input = scope.create_local(u32::as_elem()); + let stride_value = scope.create_local(u32::as_elem()); + let shape_value = scope.create_local(u32::as_elem()); + let shape_input = scope.create_local(u32::as_elem()); + let range_start = scope.create_local(u32::as_elem()); for i in 0..self.rank { cpa!(scope, stride_input = stride(input, i)); @@ -48,7 +49,7 @@ impl SliceAssignComputeShader { scope, range_start = cast(Variable::new( VariableKind::GlobalScalar(i as u16), - Item::new(Elem::UInt) + Item::new(u32::as_elem()) )) ); @@ -98,7 +99,7 @@ impl Kernel for SliceAssignEagerKernel { visibility: Visibility::Read, }; let ranges = InputInfo::Scalar { - elem: Elem::UInt, + elem: u32::as_elem(), size: self.rank, }; diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index aaf202cab..8124731dc 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -1,6 +1,7 @@ use cubecl::{ cpa, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -27,23 +28,23 @@ impl InterpolateBicubicShader { let id = Variable::builtin(Builtin::AbsolutePos); let elem = E::cube_elem(); - let input_stride_0 = scope.create_local(Elem::UInt); - let input_stride_1 = scope.create_local(Elem::UInt); - let input_stride_2 = scope.create_local(Elem::UInt); - let input_stride_3 = scope.create_local(Elem::UInt); + let input_stride_0 = scope.create_local(u32::as_elem()); + let input_stride_1 = scope.create_local(u32::as_elem()); + let input_stride_2 = scope.create_local(u32::as_elem()); + let input_stride_3 = scope.create_local(u32::as_elem()); - let input_shape_2 = scope.create_local(Elem::UInt); - let input_shape_3 = scope.create_local(Elem::UInt); + let input_shape_2 = scope.create_local(u32::as_elem()); + let input_shape_3 = scope.create_local(u32::as_elem()); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = 
scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, input_stride_0 = stride(input, 0u32)); cpa!(scope, input_stride_1 = stride(input, 1u32)); @@ -63,10 +64,10 @@ impl InterpolateBicubicShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let h = scope.create_local(Elem::UInt); - let w = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let c = scope.create_local(u32::as_elem()); + let h = scope.create_local(u32::as_elem()); + let w = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -80,23 +81,23 @@ impl InterpolateBicubicShader { cpa!(scope, w = id / output_stride_3); cpa!(scope, w = w % output_shape_3); - let input_height = scope.create_local(Elem::UInt); - let output_height = scope.create_local(Elem::UInt); + let input_height = scope.create_local(u32::as_elem()); + let output_height = scope.create_local(u32::as_elem()); let output_height_float = scope.create_local(elem); - let input_width = scope.create_local(Elem::UInt); - let output_width = scope.create_local(Elem::UInt); + let input_width = scope.create_local(u32::as_elem()); + let output_width = scope.create_local(u32::as_elem()); let output_width_float = scope.create_local(elem); let frac = scope.create_local(elem); - let numerator = scope.create_local(Elem::UInt); + let numerator = scope.create_local(u32::as_elem()); let numerator_float = scope.create_local(elem); let not_zero = scope.create_local(Elem::Bool); let y_in_float = scope.create_local(elem); - let y_in = scope.create_local(Elem::UInt); + let y_in = scope.create_local(u32::as_elem()); let yw = scope.create_local(elem); - let y_tmp = scope.create_local(Elem::UInt); + let y_tmp = scope.create_local(u32::as_elem()); cpa!(scope, input_height = input_shape_2 - 1u32); cpa!(scope, output_height = output_shape_2 - 1u32); @@ -109,7 +110,7 @@ impl InterpolateBicubicShader { cpa!(scope, y_in = cast(y_in_float)); cpa!(scope, yw = frac - y_in_float); - let y0 = scope.zero(Elem::UInt); + let y0 = scope.zero(u32::as_elem()); cpa!(scope, not_zero = y_in != 0u32); cpa!(scope, if(not_zero).then(|scope|{ cpa!(scope, y0 = y_in - 1u32); @@ -124,9 +125,9 @@ impl InterpolateBicubicShader { let y3 = Self::min(scope, y_tmp, input_height); let x_in_float = scope.create_local(elem); - let x_in = scope.create_local(Elem::UInt); + let x_in = scope.create_local(u32::as_elem()); let xw = scope.create_local(elem); - let x_tmp = scope.create_local(Elem::UInt); + let x_tmp = scope.create_local(u32::as_elem()); cpa!(scope, input_width = input_shape_3 - 1u32); cpa!(scope, output_width = output_shape_3 - 1u32); @@ -139,7 +140,7 @@ impl InterpolateBicubicShader { cpa!(scope, x_in = cast(x_in_float)); cpa!(scope, xw = frac - x_in_float); - let x0 = scope.zero(Elem::UInt); + let x0 = scope.zero(u32::as_elem()); cpa!(scope, not_zero = x_in != 0u32); cpa!(scope, 
if(not_zero).then(|scope|{ cpa!(scope, x0 = x_in - 1u32); @@ -154,20 +155,20 @@ impl InterpolateBicubicShader { cpa!(scope, x_tmp = x_in + 2u32); let x3 = Self::min(scope, x_tmp, input_width); - let index_base = scope.create_local(Elem::UInt); - let index_tmp = scope.create_local(Elem::UInt); + let index_base = scope.create_local(u32::as_elem()); + let index_tmp = scope.create_local(u32::as_elem()); cpa!(scope, index_base = b * input_stride_0); cpa!(scope, index_tmp = c * input_stride_1); cpa!(scope, index_base += index_tmp); - let y0_stride = scope.create_local(Elem::UInt); - let y1_stride = scope.create_local(Elem::UInt); - let y2_stride = scope.create_local(Elem::UInt); - let y3_stride = scope.create_local(Elem::UInt); - let x0_stride = scope.create_local(Elem::UInt); - let x1_stride = scope.create_local(Elem::UInt); - let x2_stride = scope.create_local(Elem::UInt); - let x3_stride = scope.create_local(Elem::UInt); + let y0_stride = scope.create_local(u32::as_elem()); + let y1_stride = scope.create_local(u32::as_elem()); + let y2_stride = scope.create_local(u32::as_elem()); + let y3_stride = scope.create_local(u32::as_elem()); + let x0_stride = scope.create_local(u32::as_elem()); + let x1_stride = scope.create_local(u32::as_elem()); + let x2_stride = scope.create_local(u32::as_elem()); + let x3_stride = scope.create_local(u32::as_elem()); cpa!(scope, y0_stride = y0 * input_stride_2); cpa!(scope, y1_stride = y1 * input_stride_2); cpa!(scope, y2_stride = y2 * input_stride_2); @@ -177,10 +178,10 @@ impl InterpolateBicubicShader { cpa!(scope, x2_stride = x2 * input_stride_3); cpa!(scope, x3_stride = x3 * input_stride_3); - let index_0 = scope.create_local(Elem::UInt); - let index_1 = scope.create_local(Elem::UInt); - let index_2 = scope.create_local(Elem::UInt); - let index_3 = scope.create_local(Elem::UInt); + let index_0 = scope.create_local(u32::as_elem()); + let index_1 = scope.create_local(u32::as_elem()); + let index_2 = scope.create_local(u32::as_elem()); + let index_3 = scope.create_local(u32::as_elem()); let inp_0 = scope.create_local(input.item); let inp_1 = scope.create_local(input.item); let inp_2 = scope.create_local(input.item); diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index f897affcb..c446262bd 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -1,6 +1,7 @@ use cubecl::{ cpa, - ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -25,23 +26,23 @@ impl InterpolateBilinearShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let input_stride_0 = scope.create_local(Elem::UInt); - let input_stride_1 = scope.create_local(Elem::UInt); - let input_stride_2 = scope.create_local(Elem::UInt); - let input_stride_3 = scope.create_local(Elem::UInt); + let input_stride_0 = scope.create_local(u32::as_elem()); + let input_stride_1 = scope.create_local(u32::as_elem()); + let input_stride_2 = scope.create_local(u32::as_elem()); + let input_stride_3 = scope.create_local(u32::as_elem()); - let input_shape_2 = scope.create_local(Elem::UInt); - let input_shape_3 = scope.create_local(Elem::UInt); + let input_shape_2 = scope.create_local(u32::as_elem()); + let input_shape_3 = 
scope.create_local(u32::as_elem()); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, input_stride_0 = stride(input, 0u32)); cpa!(scope, input_stride_1 = stride(input, 1u32)); @@ -61,10 +62,10 @@ impl InterpolateBilinearShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let h = scope.create_local(Elem::UInt); - let w = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let c = scope.create_local(u32::as_elem()); + let h = scope.create_local(u32::as_elem()); + let w = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -80,22 +81,22 @@ impl InterpolateBilinearShader { let factor_float = scope.create_local(input.item); let numerator_float = scope.create_local(input.item); - let numerator_int = scope.create_local(Elem::UInt); + let numerator_int = scope.create_local(u32::as_elem()); let denominator_float = scope.create_local(input.item); - let denominator_int = scope.create_local(Elem::UInt); + let denominator_int = scope.create_local(u32::as_elem()); let frac = scope.create_local(input.item); let v0 = scope.create_local(input.item); let v1 = scope.create_local(input.item); let one = scope.create_with_value(1f32, input.item); - let y0 = scope.create_local(Elem::UInt); - let y1 = scope.create_local(Elem::UInt); + let y0 = scope.create_local(u32::as_elem()); + let y1 = scope.create_local(u32::as_elem()); let yw = scope.create_local(input.item); let yw_ = scope.create_local(input.item); - let x0 = scope.create_local(Elem::UInt); - let x1 = scope.create_local(Elem::UInt); + let x0 = scope.create_local(u32::as_elem()); + let x1 = scope.create_local(u32::as_elem()); let xw = scope.create_local(input.item); let xw_ = scope.create_local(input.item); @@ -129,13 +130,13 @@ impl InterpolateBilinearShader { cpa!(scope, x0 = cast(v0)); cpa!(scope, x1 = cast(v1)); - let index_base = scope.create_local(Elem::UInt); - let index_tmp = scope.create_local(Elem::UInt); - let index = scope.create_local(Elem::UInt); - let y0_stride = scope.create_local(Elem::UInt); - let y1_stride = scope.create_local(Elem::UInt); - let x0_stride = scope.create_local(Elem::UInt); - let x1_stride = scope.create_local(Elem::UInt); + let index_base = scope.create_local(u32::as_elem()); + let index_tmp = scope.create_local(u32::as_elem()); + let index = scope.create_local(u32::as_elem()); + let y0_stride = scope.create_local(u32::as_elem()); + let y1_stride = scope.create_local(u32::as_elem()); + let x0_stride = 
scope.create_local(u32::as_elem()); + let x1_stride = scope.create_local(u32::as_elem()); let p_a = scope.create_local(input.item); let p_b = scope.create_local(input.item); let p_c = scope.create_local(input.item); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 26c99e3be..f0d5bdf8f 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -1,6 +1,7 @@ use cubecl::{ cpa, - ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -27,23 +28,23 @@ impl InterpolateNearestShader { let id = Variable::builtin(Builtin::AbsolutePos); let elem = E::cube_elem(); - let input_stride_0 = scope.create_local(Elem::UInt); - let input_stride_1 = scope.create_local(Elem::UInt); - let input_stride_2 = scope.create_local(Elem::UInt); - let input_stride_3 = scope.create_local(Elem::UInt); + let input_stride_0 = scope.create_local(u32::as_elem()); + let input_stride_1 = scope.create_local(u32::as_elem()); + let input_stride_2 = scope.create_local(u32::as_elem()); + let input_stride_3 = scope.create_local(u32::as_elem()); - let input_shape_2 = scope.create_local(Elem::UInt); - let input_shape_3 = scope.create_local(Elem::UInt); + let input_shape_2 = scope.create_local(u32::as_elem()); + let input_shape_3 = scope.create_local(u32::as_elem()); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, input_stride_0 = stride(input, 0u32)); cpa!(scope, input_stride_1 = stride(input, 1u32)); @@ -63,10 +64,10 @@ impl InterpolateNearestShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let h = scope.create_local(Elem::UInt); - let w = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let c = scope.create_local(u32::as_elem()); + let h = scope.create_local(u32::as_elem()); + let w = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -85,8 +86,8 @@ impl InterpolateNearestShader { let denominator_float = scope.create_local(elem); let x = scope.create_local(elem); let y = scope.create_local(elem); - let xu = scope.create_local(Elem::UInt); - let yu = scope.create_local(Elem::UInt); + let xu = scope.create_local(u32::as_elem()); + let yu = 
scope.create_local(u32::as_elem()); cpa!(scope, factor_float = cast(h)); cpa!(scope, numerator_float = cast(input_shape_2)); @@ -104,8 +105,8 @@ impl InterpolateNearestShader { cpa!(scope, x = floor(x)); cpa!(scope, xu = cast(x)); - let index = scope.create_local(Elem::UInt); - let index_tmp = scope.create_local(Elem::UInt); + let index = scope.create_local(u32::as_elem()); + let index_tmp = scope.create_local(u32::as_elem()); let val = scope.create_local(output.item); cpa!(scope, index = b * input_stride_0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 57d602db6..1011daad0 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -1,6 +1,7 @@ use cubecl::{ cpa, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility}, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -26,25 +27,25 @@ impl InterpolateNearestBackwardShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let grad_stride_0 = scope.create_local(Elem::UInt); - let grad_stride_1 = scope.create_local(Elem::UInt); - let grad_stride_2 = scope.create_local(Elem::UInt); - let grad_stride_3 = scope.create_local(Elem::UInt); + let grad_stride_0 = scope.create_local(u32::as_elem()); + let grad_stride_1 = scope.create_local(u32::as_elem()); + let grad_stride_2 = scope.create_local(u32::as_elem()); + let grad_stride_3 = scope.create_local(u32::as_elem()); - let grad_shape_0 = scope.create_local(Elem::UInt); - let grad_shape_1 = scope.create_local(Elem::UInt); - let grad_shape_2 = scope.create_local(Elem::UInt); - let grad_shape_3 = scope.create_local(Elem::UInt); + let grad_shape_0 = scope.create_local(u32::as_elem()); + let grad_shape_1 = scope.create_local(u32::as_elem()); + let grad_shape_2 = scope.create_local(u32::as_elem()); + let grad_shape_3 = scope.create_local(u32::as_elem()); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, grad_stride_0 = stride(grad, 0u32)); cpa!(scope, grad_stride_1 = stride(grad, 1u32)); @@ -66,10 +67,10 @@ impl InterpolateNearestBackwardShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let oh = scope.create_local(Elem::UInt); - let ow = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let c = scope.create_local(u32::as_elem()); + let oh = 
scope.create_local(u32::as_elem()); + let ow = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -90,11 +91,11 @@ impl InterpolateNearestBackwardShader { let result = scope.create_local(grad.item); - let index_grad = scope.create_local(Elem::UInt); - let index_grad_0 = scope.create_local(Elem::UInt); - let index_grad_1 = scope.create_local(Elem::UInt); - let index_grad_2 = scope.create_local(Elem::UInt); - let index_grad_3 = scope.create_local(Elem::UInt); + let index_grad = scope.create_local(u32::as_elem()); + let index_grad_0 = scope.create_local(u32::as_elem()); + let index_grad_1 = scope.create_local(u32::as_elem()); + let index_grad_2 = scope.create_local(u32::as_elem()); + let index_grad_3 = scope.create_local(u32::as_elem()); cpa!(scope, index_grad_0 = b * grad_stride_0); cpa!(scope, index_grad_1 = c * grad_stride_1); @@ -135,7 +136,7 @@ impl InterpolateNearestBackwardShader { let elem = E::cube_elem(); let numerator_float = scope.create_local(elem); let div = scope.create_local(elem); - let index = scope.create_local(Elem::UInt); + let index = scope.create_local(u32::as_elem()); cpa!(scope, index = input_index * output_size); cpa!(scope, numerator_float = cast(index)); @@ -156,9 +157,9 @@ impl InterpolateNearestBackwardShader { let elem = E::cube_elem(); let numerator_float = scope.create_local(elem); let div = scope.create_local(elem); - let index = scope.create_local(Elem::UInt); + let index = scope.create_local(u32::as_elem()); let min = scope.create_local(Elem::Bool); - let end_index = scope.create_local(Elem::UInt); + let end_index = scope.create_local(u32::as_elem()); cpa!(scope, index = input_index + 1u32); cpa!(scope, index *= output_size); diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 1bd7187c5..b5804f9c3 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -23,7 +23,7 @@ pub struct MatmulAutotuneOperationSet { impl MatmulAutotuneOperationSet { fn new(lhs: JitTensor, rhs: JitTensor, out: JitTensor) -> Self { Self { - key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)), + key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())), lhs, rhs, out, diff --git a/crates/burn-jit/src/kernel/matmul/tune/key.rs b/crates/burn-jit/src/kernel/matmul/tune/key.rs index df5d04aad..5e830d2a7 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/key.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/key.rs @@ -1,5 +1,5 @@ use crate::tune::anchor; -use burn_tensor::Shape; +use burn_tensor::{DType, Shape}; use core::fmt::Debug; use serde::{Deserialize, Serialize}; use std::{cmp::max, fmt::Display, hash::Hash}; @@ -13,19 +13,21 @@ pub struct MatmulAutotuneKey { anchored_k: usize, anchored_n: usize, anchored_batch: usize, + dtype: DType, } impl Display for MatmulAutotuneKey { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str( format!( - "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}", + "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?} dtype:{:?}", self.round, self.broadcast, self.anchored_m, self.anchored_k, self.anchored_n, - self.anchored_batch + self.anchored_batch, + self.dtype ) .as_str(), ) @@ -34,7 +36,7 @@ impl Display for MatmulAutotuneKey { impl MatmulAutotuneKey { /// Create a matmul autotune key from the input shapes - pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) 
-> Self { + pub fn new(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { let ndims = lhs_shape.num_dims(); let m = lhs_shape.dims[ndims - 2]; let k = lhs_shape.dims[ndims - 1]; @@ -62,6 +64,7 @@ impl MatmulAutotuneKey { anchored_k: anchor(k, None), anchored_n: anchor(n, None), anchored_batch: anchor(batch_product, Some(256)), + dtype, } } } @@ -74,7 +77,7 @@ mod tests { fn matmul_autotune_key_all_same_and_round() { let lhs_shape: Shape = [4, 512, 512].into(); let rhs_shape: Shape = [4, 512, 512].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32); assert!(key.round); assert!(!key.broadcast); @@ -87,7 +90,7 @@ mod tests { fn matmul_autotune_key_all_different() { let lhs_shape: Shape = [2, 3, 511, 512].into(); let rhs_shape: Shape = [3, 2, 512, 513].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32); assert!(!key.round); assert!(key.broadcast); @@ -101,7 +104,7 @@ mod tests { fn matmul_autotune_key_large_batch() { let lhs_shape: Shape = [128, 512, 511, 512].into(); let rhs_shape: Shape = [200, 400, 512, 513].into(); - let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape); + let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32); assert!(key.anchored_batch == 256); } diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index a6e62c846..38ef67054 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -10,6 +10,7 @@ use cubecl::{ ir::{ Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility, }, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -36,23 +37,23 @@ impl AvgPool2dBackwardComputeShader { let output = self.output; let id = Variable::builtin(Builtin::AbsolutePos); - let grad_stride_0 = scope.create_local(Elem::UInt); - let grad_stride_1 = scope.create_local(Elem::UInt); - let grad_stride_2 = scope.create_local(Elem::UInt); - let grad_stride_3 = scope.create_local(Elem::UInt); + let grad_stride_0 = scope.create_local(u32::as_elem()); + let grad_stride_1 = scope.create_local(u32::as_elem()); + let grad_stride_2 = scope.create_local(u32::as_elem()); + let grad_stride_3 = scope.create_local(u32::as_elem()); - let grad_shape_2 = scope.create_local(Elem::UInt); - let grad_shape_3 = scope.create_local(Elem::UInt); + let grad_shape_2 = scope.create_local(u32::as_elem()); + let grad_shape_3 = scope.create_local(u32::as_elem()); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = 
scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, grad_stride_0 = stride(grad, 0u32)); cpa!(scope, grad_stride_1 = stride(grad, 1u32)); @@ -72,16 +73,16 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt)); - let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt)); - let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt)); - let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt)); + let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem())); + let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem())); + let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem())); + let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem())); let [kernel_size_0, kernel_size_1] = self.kernel_size; - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let ih = scope.create_local(Elem::UInt); - let iw = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let c = scope.create_local(u32::as_elem()); + let ih = scope.create_local(u32::as_elem()); + let iw = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -95,16 +96,16 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, iw = id / output_stride_3); cpa!(scope, iw = iw % output_shape_3); - let index_current = scope.create_local(Elem::UInt); - let index_current_tmp = scope.create_local(Elem::UInt); + let index_current = scope.create_local(u32::as_elem()); + let index_current_tmp = scope.create_local(u32::as_elem()); cpa!(scope, index_current = ih * output_stride_2); cpa!(scope, index_current_tmp = iw * output_stride_3); cpa!(scope, index_current += index_current_tmp); - let index = scope.create_local(Elem::UInt); - let index_tmp = scope.create_local(Elem::UInt); - let index_base = scope.create_local(Elem::UInt); + let index = scope.create_local(u32::as_elem()); + let index_tmp = scope.create_local(u32::as_elem()); + let index_base = scope.create_local(u32::as_elem()); let grad_accumulation = scope.zero(grad.item); let result = scope.create_local(grad.item); @@ -130,14 +131,14 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, index_tmp = c * grad_stride_1); cpa!(scope, index_base += index_tmp); - let border_bottom = scope.create_local(Elem::UInt); - let border_right = scope.create_local(Elem::UInt); - let begin_h = scope.create_local(Elem::UInt); - let begin_w = scope.create_local(Elem::UInt); - let iw_start = scope.create_local(Elem::UInt); - let iw_end = scope.create_local(Elem::UInt); - let ih_start = scope.create_local(Elem::UInt); - let ih_end = scope.create_local(Elem::UInt); + let border_bottom = scope.create_local(u32::as_elem()); + let border_right = scope.create_local(u32::as_elem()); + let begin_h = scope.create_local(u32::as_elem()); + let begin_w = scope.create_local(u32::as_elem()); + let iw_start = scope.create_local(u32::as_elem()); + let iw_end = scope.create_local(u32::as_elem()); + let ih_start = scope.create_local(u32::as_elem()); + let ih_end = scope.create_local(u32::as_elem()); let after_start = scope.create_local(Elem::Bool); let before_end = scope.create_local(Elem::Bool); 
let contributed_h = scope.create_local(Elem::Bool); @@ -147,9 +148,9 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, begin_h = ih + padding_0); cpa!(scope, begin_w = iw + padding_1); - let ih_diff = scope.create_local(Elem::UInt); - let iw_diff = scope.create_local(Elem::UInt); - let count_int = scope.create_local(Elem::UInt); + let ih_diff = scope.create_local(u32::as_elem()); + let iw_diff = scope.create_local(u32::as_elem()); + let count_int = scope.create_local(u32::as_elem()); cpa!( scope, @@ -216,12 +217,12 @@ impl AvgPool2dBackwardComputeShader { output_stride_2: Variable, output_stride_3: Variable, ) -> (Variable, Variable, Variable, Variable) { - let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt)); - let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt)); - let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt)); - let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt)); - let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt)); - let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt)); + let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem())); + let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem())); + let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem())); + let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem())); + let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem())); + let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem())); let [kernel_size_0, kernel_size_1] = self.kernel_size; @@ -273,8 +274,8 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, oh_start_tmp = max(oh_start_tmp, 0i32)); cpa!(scope, ow_start_tmp = max(ow_start_tmp, 0i32)); - let oh_start = scope.create_local(Elem::UInt); - let ow_start = scope.create_local(Elem::UInt); + let oh_start = scope.create_local(u32::as_elem()); + let ow_start = scope.create_local(u32::as_elem()); cpa!(scope, oh_start = cast(oh_start_tmp)); cpa!(scope, ow_start = cast(ow_start_tmp)); @@ -285,11 +286,11 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, oh_end_tmp = max(kms_0, 0i32)); cpa!(scope, ow_end_tmp = max(kms_1, 0i32)); - let oh_end = scope.create_local(Elem::UInt); - let ow_end = scope.create_local(Elem::UInt); + let oh_end = scope.create_local(u32::as_elem()); + let ow_end = scope.create_local(u32::as_elem()); - let oh_end_limit = scope.create_local(Elem::UInt); - let ow_end_limit = scope.create_local(Elem::UInt); + let oh_end_limit = scope.create_local(u32::as_elem()); + let ow_end_limit = scope.create_local(u32::as_elem()); cpa!(scope, oh_end = cast(oh_end_tmp)); cpa!(scope, ow_end = cast(ow_end_tmp)); @@ -303,8 +304,8 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, oh_end = min(oh_end, oh_end_limit)); cpa!(scope, ow_end = min(ow_end, ow_end_limit)); - let index_current = scope.create_local(Elem::UInt); - let index_current_tmp = scope.create_local(Elem::UInt); + let index_current = scope.create_local(u32::as_elem()); + let index_current_tmp = scope.create_local(u32::as_elem()); cpa!(scope, index_current = ih * output_stride_2); cpa!(scope, index_current_tmp = iw * output_stride_3); @@ -340,7 +341,7 @@ impl Kernel for AvgPool2dBackwardEagerKernel visibility: Visibility::Read, }; let scalars = InputInfo::Scalar { - elem: 
Elem::UInt, + elem: u32::as_elem(), size: 6, }; let output = OutputInfo::Array { item }; diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 58dd09454..dc8830eac 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -10,6 +10,7 @@ use cubecl::{ ir::{ Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility, }, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -36,23 +37,23 @@ impl MaxPool2dBackwardComputeShader { let indices = self.indices; let id = Variable::builtin(Builtin::AbsolutePos); - let grad_stride_0 = scope.create_local(Elem::UInt); - let grad_stride_1 = scope.create_local(Elem::UInt); - let grad_stride_2 = scope.create_local(Elem::UInt); - let grad_stride_3 = scope.create_local(Elem::UInt); + let grad_stride_0 = scope.create_local(u32::as_elem()); + let grad_stride_1 = scope.create_local(u32::as_elem()); + let grad_stride_2 = scope.create_local(u32::as_elem()); + let grad_stride_3 = scope.create_local(u32::as_elem()); - let grad_shape_2 = scope.create_local(Elem::UInt); - let grad_shape_3 = scope.create_local(Elem::UInt); + let grad_shape_2 = scope.create_local(u32::as_elem()); + let grad_shape_3 = scope.create_local(u32::as_elem()); - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); + let output_stride_0 = scope.create_local(u32::as_elem()); + let output_stride_1 = scope.create_local(u32::as_elem()); + let output_stride_2 = scope.create_local(u32::as_elem()); + let output_stride_3 = scope.create_local(u32::as_elem()); - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); + let output_shape_0 = scope.create_local(u32::as_elem()); + let output_shape_1 = scope.create_local(u32::as_elem()); + let output_shape_2 = scope.create_local(u32::as_elem()); + let output_shape_3 = scope.create_local(u32::as_elem()); cpa!(scope, grad_stride_0 = stride(grad, 0u32)); cpa!(scope, grad_stride_1 = stride(grad, 1u32)); @@ -72,10 +73,10 @@ impl MaxPool2dBackwardComputeShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let ih = scope.create_local(Elem::UInt); - let iw = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); + let c = scope.create_local(u32::as_elem()); + let ih = scope.create_local(u32::as_elem()); + let iw = scope.create_local(u32::as_elem()); cpa!(scope, b = id / output_stride_0); cpa!(scope, b = b % output_shape_0); @@ -89,8 +90,8 @@ impl MaxPool2dBackwardComputeShader { cpa!(scope, iw = id / output_stride_3); cpa!(scope, iw = iw % output_shape_3); - let index_current = scope.create_local(Elem::UInt); - let index_current_tmp = scope.create_local(Elem::UInt); + let index_current = scope.create_local(u32::as_elem()); + let index_current_tmp = scope.create_local(u32::as_elem()); cpa!(scope, index_current = ih * output_stride_2); cpa!(scope, index_current_tmp = iw * output_stride_3); @@ -98,12 +99,12 @@ impl 
MaxPool2dBackwardComputeShader { let index_select = scope.create_local(Elem::Int(IntKind::I32)); - let index_max = scope.create_local(Elem::UInt); + let index_max = scope.create_local(u32::as_elem()); let is_max = scope.create_local(Elem::Bool); - let index = scope.create_local(Elem::UInt); - let index_base = scope.create_local(Elem::UInt); - let index_tmp = scope.create_local(Elem::UInt); + let index = scope.create_local(u32::as_elem()); + let index_base = scope.create_local(u32::as_elem()); + let index_tmp = scope.create_local(u32::as_elem()); let grad_accumulation = scope.zero(grad.item); let result = scope.create_local(grad.item); @@ -162,12 +163,12 @@ impl MaxPool2dBackwardComputeShader { output_stride_2: Variable, output_stride_3: Variable, ) -> (Variable, Variable, Variable, Variable) { - let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt)); - let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt)); - let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt)); - let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt)); - let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt)); - let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt)); + let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem())); + let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem())); + let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem())); + let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem())); + let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem())); + let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem())); let [kernel_size_0, kernel_size_1] = self.kernel_size; @@ -219,8 +220,8 @@ impl MaxPool2dBackwardComputeShader { cpa!(scope, oh_start_tmp = max(oh_start_tmp, 0i32)); cpa!(scope, ow_start_tmp = max(ow_start_tmp, 0i32)); - let oh_start = scope.create_local(Elem::UInt); - let ow_start = scope.create_local(Elem::UInt); + let oh_start = scope.create_local(u32::as_elem()); + let ow_start = scope.create_local(u32::as_elem()); cpa!(scope, oh_start = cast(oh_start_tmp)); cpa!(scope, ow_start = cast(ow_start_tmp)); @@ -231,11 +232,11 @@ impl MaxPool2dBackwardComputeShader { cpa!(scope, oh_end_tmp = max(kms_0, 0i32)); cpa!(scope, ow_end_tmp = max(kms_1, 0i32)); - let oh_end = scope.create_local(Elem::UInt); - let ow_end = scope.create_local(Elem::UInt); + let oh_end = scope.create_local(u32::as_elem()); + let ow_end = scope.create_local(u32::as_elem()); - let oh_end_limit = scope.create_local(Elem::UInt); - let ow_end_limit = scope.create_local(Elem::UInt); + let oh_end_limit = scope.create_local(u32::as_elem()); + let ow_end_limit = scope.create_local(u32::as_elem()); cpa!(scope, oh_end = cast(oh_end_tmp)); cpa!(scope, ow_end = cast(ow_end_tmp)); @@ -249,8 +250,8 @@ impl MaxPool2dBackwardComputeShader { cpa!(scope, oh_end = min(oh_end, oh_end_limit)); cpa!(scope, ow_end = min(ow_end, ow_end_limit)); - let index_current = scope.create_local(Elem::UInt); - let index_current_tmp = scope.create_local(Elem::UInt); + let index_current = scope.create_local(u32::as_elem()); + let index_current_tmp = scope.create_local(u32::as_elem()); cpa!(scope, index_current = ih * output_stride_2); cpa!(scope, index_current_tmp = iw * output_stride_3); @@ -295,7 
+296,7 @@ impl Kernel for MaxPool2dWithIndicesBackwardEagerK visibility: Visibility::Read, }; let scalars = InputInfo::Scalar { - elem: Elem::UInt, + elem: u32::as_elem(), size: 6, }; let output = OutputInfo::Array { item }; diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index 527c85d79..3c5e2bbac 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -1,6 +1,6 @@ use cubecl::{ cpa, - ir::{Builtin, Elem, Item, Scope, Variable, VariableKind}, + ir::{Builtin, Item, Scope, Variable, VariableKind}, prelude::*, CubeCountSettings, Execution, InputInfo, OutputInfo, }; @@ -60,10 +60,10 @@ impl, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel::new(); @@ -80,7 +80,7 @@ impl, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel, E: JitElement> PrngShader { let cube_count_y = Variable::builtin(Builtin::CubeCountY); let local_index = Variable::builtin(Builtin::UnitPos); - let n_invocations = scope.create_local(Elem::UInt); + let n_invocations = scope.create_local(u32::as_elem()); cpa!(scope, n_invocations = cube_dim_x); cpa!(scope, n_invocations *= cube_dim_y); - let cube_offset = scope.create_local(Elem::UInt); + let cube_offset = scope.create_local(u32::as_elem()); cpa!(scope, cube_offset = cube_pos_x * cube_count_y); cpa!(scope, cube_offset += cube_pos_y); cpa!(scope, cube_offset *= n_invocations); - let write_index_base = scope.create_local(Elem::UInt); + let write_index_base = scope.create_local(u32::as_elem()); cpa!(scope, write_index_base = cube_offset); cpa!(scope, write_index_base *= n_values_per_thread); cpa!(scope, write_index_base += local_index); // Set state with unique seeds - let thread_seed = scope.create_local(Elem::UInt); + let thread_seed = scope.create_local(u32::as_elem()); cpa!(scope, thread_seed = cast(1000000007)); - let thread_seed_index = scope.create_local(Elem::UInt); + let thread_seed_index = scope.create_local(u32::as_elem()); cpa!(scope, thread_seed_index = cube_offset + local_index); cpa!(scope, thread_seed *= thread_seed_index); - let state_0 = scope.create_local(Elem::UInt); + let state_0 = scope.create_local(u32::as_elem()); cpa!(scope, state_0 = thread_seed); cpa!(scope, state_0 += seed_0); - let state_1 = scope.create_local(Elem::UInt); + let state_1 = scope.create_local(u32::as_elem()); cpa!(scope, state_1 = thread_seed); cpa!(scope, state_1 += seed_1); - let state_2 = scope.create_local(Elem::UInt); + let state_2 = scope.create_local(u32::as_elem()); cpa!(scope, state_2 = thread_seed); cpa!(scope, state_2 += seed_2); - let state_3 = scope.create_local(Elem::UInt); + let state_3 = scope.create_local(u32::as_elem()); cpa!(scope, state_3 = thread_seed); cpa!(scope, state_3 += seed_3); @@ -260,7 +260,7 @@ fn taus_step( s3: Variable, m: Variable, ) { - let b = scope.create_local(Elem::UInt); + let b = scope.create_local(u32::as_elem()); cpa!(scope, b = z << s1); cpa!(scope, b = b ^ z); cpa!(scope, b = b >> s2); diff --git a/crates/burn-jit/src/kernel/prng/bernoulli.rs b/crates/burn-jit/src/kernel/prng/bernoulli.rs index 15053b3fa..6ab8551ae 100644 --- a/crates/burn-jit/src/kernel/prng/bernoulli.rs +++ b/crates/burn-jit/src/kernel/prng/bernoulli.rs @@ -2,6 +2,7 @@ use burn_tensor::Shape; use cubecl::{ cpa, ir::{Elem, FloatKind, Scope, Variable}, + prelude::*, }; use crate::{ @@ -34,7 +35,14 @@ impl Prng for Bernoulli { output: Variable, ) { let float_elem = Elem::Float(FloatKind::F32); - let prob = args[0]; + let mut prob = args[0]; + + if prob.item.elem() != 
Elem::Float(FloatKind::F32) { + let prob_f32 = scope.create_local(float_elem); + cpa!(scope, prob_f32 = cast(prob)); + prob = prob_f32; + } + cpa!( scope, range(0u32, n_values_per_thread).for_each(|i, scope| { @@ -43,7 +51,7 @@ impl Prng for Bernoulli { taus_step_2(scope, state_2); lcg_step(scope, state_3); - let int_random = scope.create_local(Elem::UInt); + let int_random = scope.create_local(u32::as_elem()); cpa!(scope, int_random = state_0 ^ state_1); cpa!(scope, int_random = int_random ^ state_2); cpa!(scope, int_random = int_random ^ state_3); @@ -54,7 +62,7 @@ impl Prng for Bernoulli { let bernoulli = scope.create_local(Elem::Bool); cpa!(scope, bernoulli = float_random < prob); - let write_index = scope.create_local(Elem::UInt); + let write_index = scope.create_local(u32::as_elem()); cpa!(scope, write_index = i * n_invocations); cpa!(scope, write_index += write_index_base); cpa!(scope, output[write_index] = bernoulli); diff --git a/crates/burn-jit/src/kernel/prng/normal.rs b/crates/burn-jit/src/kernel/prng/normal.rs index e9b4b428a..e3466267d 100644 --- a/crates/burn-jit/src/kernel/prng/normal.rs +++ b/crates/burn-jit/src/kernel/prng/normal.rs @@ -1,6 +1,7 @@ use cubecl::{ cpa, ir::{Elem, FloatKind, Scope, Variable}, + prelude::*, }; use std::f32::consts::PI; @@ -47,7 +48,7 @@ impl Prng for Normal { cpa!( scope, range(0u32, n_values_per_thread / 2).for_each(|i, scope| { - let int_random = scope.create_local(Elem::UInt); + let int_random = scope.create_local(u32::as_elem()); // First random uniform integer taus_step_0(scope, state_0); @@ -95,9 +96,9 @@ impl Prng for Normal { cpa!(scope, normal_1 += mean); // Write to output - let write_index_0 = scope.create_local(Elem::UInt); - let write_index_1 = scope.create_local(Elem::UInt); - let iteration_offset = scope.create_local(Elem::UInt); + let write_index_0 = scope.create_local(u32::as_elem()); + let write_index_1 = scope.create_local(u32::as_elem()); + let iteration_offset = scope.create_local(u32::as_elem()); cpa!(scope, write_index_0 = write_index_base); cpa!(scope, iteration_offset = two * i); cpa!(scope, iteration_offset *= n_invocations); diff --git a/crates/burn-jit/src/kernel/prng/uniform.rs b/crates/burn-jit/src/kernel/prng/uniform.rs index 15490758d..7d0467325 100644 --- a/crates/burn-jit/src/kernel/prng/uniform.rs +++ b/crates/burn-jit/src/kernel/prng/uniform.rs @@ -2,6 +2,7 @@ use burn_tensor::Shape; use cubecl::{ cpa, ir::{Elem, FloatKind, Scope, Variable}, + prelude::*, }; use crate::{ @@ -49,7 +50,7 @@ impl Prng for Uniform { taus_step_2(scope, state_2); lcg_step(scope, state_3); - let int_random = scope.create_local(Elem::UInt); + let int_random = scope.create_local(u32::as_elem()); cpa!(scope, int_random = state_0 ^ state_1); cpa!(scope, int_random = int_random ^ state_2); cpa!(scope, int_random = int_random ^ state_3); @@ -65,7 +66,7 @@ impl Prng for Uniform { cpa!(scope, uniform = cast(uniform_float)); cpa!(scope, uniform += lower_bound); - let write_index = scope.create_local(Elem::UInt); + let write_index = scope.create_local(u32::as_elem()); cpa!(scope, write_index = i * n_invocations); cpa!(scope, write_index += write_index_base); cpa!(scope, output[write_index] = uniform); diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs index 1b10a5f10..df8374f63 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs @@ -3,6 +3,7 @@ use burn_tensor::cast::ToElement; use cubecl::{ cpa, ir::{Elem, 
Item, Scope, Variable}, + prelude::*, }; use super::base::ReduceDimShared; @@ -17,7 +18,7 @@ impl ReduceDimShared for Argmax { input_item: Item, ) -> Self::Accumulator { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); - let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); + let index_shared_memory = scope.create_shared(u32::as_elem(), shared_memory_size); let max = input_item .elem() .constant_from_f64(ToElement::to_f64(&E::minimum_value())); diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs index fa8ff21af..b64235752 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs @@ -2,6 +2,7 @@ use burn_tensor::cast::ToElement; use cubecl::{ cpa, ir::{Elem, Item, Scope, Variable}, + prelude::*, }; use crate::{kernel::reduce::Argmin, JitElement}; @@ -18,7 +19,7 @@ impl ReduceDimShared for Argmin { input_item: Item, ) -> Self::Accumulator { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); - let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); + let index_shared_memory = scope.create_shared(u32::as_elem(), shared_memory_size); let min = input_item .elem() .constant_from_f64(ToElement::to_f64(&E::maximum_value())); diff --git a/crates/burn-jit/src/kernel/reduce/shared/shader.rs b/crates/burn-jit/src/kernel/reduce/shared/shader.rs index 3aa48ff8a..a74f51fbe 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/shader.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/shader.rs @@ -2,6 +2,7 @@ use cubecl::{ cpa, ir::{Builtin, KernelDefinition, VariableKind}, prelude::CubeCount, + prelude::*, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -120,38 +121,38 @@ impl> SharedReduceDimComputeShader let cube_dim_x = Variable::builtin(Builtin::CubeDimX); let cube_dim_y = Variable::builtin(Builtin::CubeDimY); - let stride_reduce_dim_input = scope.create_local(Elem::UInt); + let stride_reduce_dim_input = scope.create_local(u32::as_elem()); cpa!(scope, stride_reduce_dim_input = stride(tensor, dim)); - let shape_reduce_dim_input = scope.create_local(Elem::UInt); + let shape_reduce_dim_input = scope.create_local(u32::as_elem()); cpa!(scope, shape_reduce_dim_input = shape(tensor, dim)); // To determine which reduce_group (not position, but absolute id) - let reduce_group_id = scope.create_local(Elem::UInt); + let reduce_group_id = scope.create_local(u32::as_elem()); cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x); cpa!(scope, reduce_group_id += cube_pos_x); // nth thread in the cube - let local_id = scope.create_local(Elem::UInt); + let local_id = scope.create_local(u32::as_elem()); cpa!(scope, local_id = local_invocation_id_y * cube_dim_x); cpa!(scope, local_id += local_invocation_id_x); - let n_threads = scope.create_local(Elem::UInt); + let n_threads = scope.create_local(u32::as_elem()); cpa!(scope, n_threads = cube_dim_x * cube_dim_y); - let index_offset = scope.zero(Elem::UInt); + let index_offset = scope.zero(u32::as_elem()); cpa!( scope, range(0u32, rank).for_each(|i, scope| { - let stride_input = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape_output = scope.create_local(Elem::UInt); + let stride_input = scope.create_local(u32::as_elem()); + let stride_output = scope.create_local(u32::as_elem()); + let shape_output = 
scope.create_local(u32::as_elem()); cpa!(scope, stride_input = stride(tensor, i)); cpa!(scope, stride_output = stride(output, i)); cpa!(scope, shape_output = shape(output, i)); - let num_block = scope.create_local(Elem::UInt); + let num_block = scope.create_local(u32::as_elem()); cpa!(scope, num_block = reduce_group_id / stride_output); cpa!(scope, num_block = num_block % shape_output); cpa!(scope, num_block = num_block * stride_input); @@ -166,14 +167,14 @@ impl> SharedReduceDimComputeShader cpa!( scope, range(0u32, self.n_input_values_per_thread).for_each(|i, scope| { - let nth = scope.create_local(Elem::UInt); + let nth = scope.create_local(u32::as_elem()); cpa!(scope, nth = i * n_threads); cpa!(scope, nth += local_id); let within_shape = scope.create_local(Elem::Bool); if self.divisible_shape { - let current_position = scope.create_local(Elem::UInt); + let current_position = scope.create_local(u32::as_elem()); cpa!(scope, current_position = nth * stride_reduce_dim_input); cpa!(scope, current_position += index_offset); @@ -182,7 +183,7 @@ impl> SharedReduceDimComputeShader } else { cpa!(scope, within_shape = nth < shape_reduce_dim_input); cpa!(scope, if(within_shape).then(|scope|{ - let current_position = scope.create_local(Elem::UInt); + let current_position = scope.create_local(u32::as_elem()); cpa!(scope, current_position = nth * stride_reduce_dim_input); cpa!(scope, current_position += index_offset); @@ -208,7 +209,7 @@ impl> SharedReduceDimComputeShader let updating_thread = scope.create_local(Elem::Bool); cpa!(scope, updating_thread = local_id < n_threads); cpa!(scope, if(updating_thread).then(|scope|{ - let read_position = scope.create_local(Elem::UInt); + let read_position = scope.create_local(u32::as_elem()); cpa!(scope, read_position = n_threads + local_id); let read_value = RD::read_from_shared(scope, shared_memory, read_position); diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs index c79114d7b..3977f528d 100644 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ b/crates/burn-jit/src/kernel/reduce/tune/base.rs @@ -44,6 +44,7 @@ impl, R: JitRuntime, EI: JitElement, EO: JitElement> &input.shape, &input.strides, reduce_dim, + EI::dtype(), )), input, output, diff --git a/crates/burn-jit/src/kernel/reduce/tune/key.rs b/crates/burn-jit/src/kernel/reduce/tune/key.rs index efcdc5915..d03ede699 100644 --- a/crates/burn-jit/src/kernel/reduce/tune/key.rs +++ b/crates/burn-jit/src/kernel/reduce/tune/key.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::{cmp::min, fmt::Display}; -use burn_tensor::Shape; +use burn_tensor::{DType, Shape}; #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)] /// Autotune key representative of reduce versions @@ -9,14 +9,15 @@ pub struct ReduceAutotuneKey { pub(crate) reduce_dim_length: usize, pub(crate) reduce_dim_stride: usize, pub(crate) others_product: usize, + dtype: DType, } impl Display for ReduceAutotuneKey { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str( format!( - "Reduce - reduce_dim_length: {:?} reduce_dim_stride: {:?} others_product: {:?}", - self.reduce_dim_length, self.reduce_dim_stride, self.others_product + "Reduce - reduce_dim_length: {:?} reduce_dim_stride: {:?} others_product: {:?} dtype: {:?}", + self.reduce_dim_length, self.reduce_dim_stride, self.others_product, self.dtype ) .as_str(), ) @@ -25,7 +26,7 @@ impl Display for ReduceAutotuneKey { impl ReduceAutotuneKey { /// Create a reduce autotune key from 
the input shape and reduce dim - pub fn new(shape: &Shape, strides: &[usize], reduce_dim: usize) -> Self { + pub fn new(shape: &Shape, strides: &[usize], reduce_dim: usize, dtype: DType) -> Self { let ndims = strides.len(); let reduce_dim_length = shape.dims[reduce_dim]; let reduce_dim_stride = strides[reduce_dim]; @@ -39,6 +40,7 @@ impl ReduceAutotuneKey { reduce_dim_length: anchor(reduce_dim_length, None), reduce_dim_stride: anchor(reduce_dim_stride, None), others_product: anchor(others_product, None), + dtype, } } } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 1576d51c9..ef04513a9 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -21,6 +21,14 @@ pub(crate) async fn into_data(tensor: JitTensor(tensor: JitTensor) -> TensorData { + let tensor = kernel::into_contiguous(tensor); + + let bytes = tensor.client.read(tensor.handle.binding()); + TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) +} + pub(crate) async fn bool_into_data(tensor: JitTensor) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_async(tensor.handle.binding()).await; diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index 8730f7098..3528f9487 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -37,8 +37,12 @@ pub use serial_test; #[macro_export] macro_rules! testgen_all { () => { + use burn_tensor::{Float, Int}; + $crate::testgen_all!([Float], [Int]); + }; + ([$($float:ident),*], [$($int:ident),*]) => { mod jit { - burn_jit::testgen_jit!(); + burn_jit::testgen_jit!([$($float),*], [$($int),*]); mod kernel { use super::*; @@ -79,7 +83,7 @@ macro_rules! testgen_all { } } mod jit_fusion { - burn_jit::testgen_jit_fusion!(); + burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*]); } }; } @@ -87,22 +91,32 @@ macro_rules! testgen_all { #[macro_export] macro_rules! testgen_jit { () => { - use super::*; + use burn_tensor::{Float, Int}; + $crate::testgen_jit!([Float], [Int]); + }; + ([$($float:ident),*], [$($int:ident),*]) => { + pub use super::*; use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test}; pub type TestBackend = JitBackend; + pub type TestBackend2 = JitBackend; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; + pub type TestTensor2 = burn_tensor::Tensor, D>; pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensorInt2 = + burn_tensor::Tensor, D, burn_tensor::Int>; pub type TestTensorBool = burn_tensor::Tensor; + pub type TestTensorBool2 = + burn_tensor::Tensor, D, burn_tensor::Bool>; pub type ReferenceTensor = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!([$($float),*], [$($int),*]); + burn_autodiff::testgen_all!([$($float),*]); // Not all ops are implemented for quantization yet, notably missing: // `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand` @@ -111,28 +125,38 @@ macro_rules! testgen_jit { burn_tensor::testgen_calibration!(); burn_tensor::testgen_scheme!(); burn_tensor::testgen_quantize!(); - }; + } } #[macro_export] macro_rules! 
testgen_jit_fusion { () => { + use burn_tensor::{Float, Int}; + $crate::testgen_jit_fusion!([Float], [Int]); + }; + ([$($float:ident),*], [$($int:ident),*]) => { use super::*; use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor}; pub type TestBackend = burn_fusion::Fusion>; + pub type TestBackend2 = burn_fusion::Fusion>; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; + pub type TestTensor2 = burn_tensor::Tensor, D>; pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensorInt2 = + burn_tensor::Tensor, D, burn_tensor::Int>; pub type TestTensorBool = burn_tensor::Tensor; + pub type TestTensorBool2 = + burn_tensor::Tensor, D, burn_tensor::Bool>; pub type ReferenceTensor = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!([$($float),*], [$($int),*]); + burn_autodiff::testgen_all!([$($float),*]); // Not all ops are implemented for quantization yet, notably missing: // `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand` diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index cc7e14e9b..9c4dfa3a4 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -2,20 +2,25 @@ authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Tensor library with user-friendly APIs and automatic differentiation support" +documentation = "https://docs.rs/burn-tensor" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-tensor" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor" -documentation = "https://docs.rs/burn-tensor" version.workspace = true [features] +cubecl = ["dep:cubecl"] +cubecl-cuda = ["cubecl", "cubecl/cuda"] +cubecl-hip = ["cubecl", "cubecl/hip"] +cubecl-wgpu = ["cubecl", "cubecl/wgpu"] default = ["std", "repr"] doc = ["default"] experimental-named-tensor = [] -export_tests = ["burn-tensor-testgen"] +export_tests = ["burn-tensor-testgen", "cubecl"] +repr = [] std = [ "rand/std", "half/std", @@ -24,24 +29,19 @@ std = [ "burn-common/rayon", "colored", ] -repr = [] -cubecl = ["dep:cubecl"] -cubecl-wgpu = ["cubecl", "cubecl/wgpu"] -cubecl-cuda = ["cubecl", "cubecl/cuda"] -cubecl-hip = ["cubecl", "cubecl/hip"] [dependencies] burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } -cubecl = { workspace = true, optional = true } +cubecl = { workspace = true, optional = true, default-features = true } +bytemuck = { workspace = true } +colored = { workspace = true, optional = true } derive-new = { workspace = true } half = { workspace = true, features = ["bytemuck"] } num-traits = { workspace = true } rand = { workspace = true } rand_distr = { workspace = true } # use instead of statrs because it supports no_std -bytemuck = { workspace = true } -colored = { workspace = true, optional = true } # The same implementation of HashMap in std but with no_std support (only needs alloc crate) hashbrown = { workspace = true } # no_std compatible @@ -50,6 +50,7 @@ hashbrown = { workspace = true } # no_std compatible serde = { workspace = true } serde_bytes = { workspace = true } + [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic-util = { workspace = true } diff --git 
a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index 2692b21b4..8b5adfbbf 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -20,7 +20,7 @@ pub mod repr; #[cfg(feature = "export_tests")] #[allow(missing_docs)] -mod tests; +pub mod tests; pub use half::{bf16, f16}; pub(crate) use tensor::check::macros::check; @@ -30,7 +30,7 @@ pub use burn_common::reader::*; // Useful so that backends don't have to add `bu #[cfg(feature = "cubecl")] mod cube { - use cubecl::ir::{Elem, FloatKind, IntKind}; + use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind}; impl From for cubecl::ir::Elem { fn from(dtype: crate::DType) -> Self { @@ -41,11 +41,12 @@ mod cube { crate::DType::BF16 => Elem::Float(FloatKind::BF16), crate::DType::I64 => Elem::Int(IntKind::I64), crate::DType::I32 => Elem::Int(IntKind::I32), - crate::DType::I16 => panic!("i16 isn't supported yet."), - crate::DType::I8 => panic!("i8 isn't supported yet."), - crate::DType::U64 => Elem::UInt, - crate::DType::U32 => Elem::UInt, - crate::DType::U8 => panic!("u8 isn't supported yet."), + crate::DType::I16 => Elem::Int(IntKind::I16), + crate::DType::I8 => Elem::Int(IntKind::I8), + crate::DType::U64 => Elem::UInt(UIntKind::U64), + crate::DType::U32 => Elem::UInt(UIntKind::U32), + crate::DType::U16 => Elem::UInt(UIntKind::U16), + crate::DType::U8 => Elem::UInt(UIntKind::U8), crate::DType::Bool => Elem::Bool, crate::DType::QFloat(_) => panic!("quantized type is not supported yet."), } diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 36df4163e..3fcf5e403 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -1,4 +1,7 @@ -use core::any::{Any, TypeId}; +use core::{ + any::{Any, TypeId}, + f32, +}; use alloc::boxed::Box; use alloc::format; @@ -181,6 +184,11 @@ impl TensorData { .map(|e: &i64| e.elem::()), ), DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::())), + DType::U16 => Box::new( + bytemuck::checked::cast_slice(&self.bytes) + .iter() + .map(|e: &u16| e.elem::()), + ), DType::U32 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() @@ -313,6 +321,7 @@ impl TensorData { DType::I8 => self.convert_inplace::(), DType::U64 => self.convert_inplace::(), DType::U32 => self.convert_inplace::(), + DType::U16 => self.convert_inplace::(), DType::U8 => self.convert_inplace::(), DType::Bool | DType::QFloat(_) => unreachable!(), } @@ -419,6 +428,7 @@ impl TensorData { DType::I8 => self.assert_eq_elem::(other), DType::U64 => self.assert_eq_elem::(other), DType::U32 => self.assert_eq_elem::(other), + DType::U16 => self.assert_eq_elem::(other), DType::U8 => self.assert_eq_elem::(other), DType::Bool => self.assert_eq_elem::(other), DType::QFloat(q) => { @@ -511,9 +521,21 @@ impl TensorData { continue; } - let err = ((a - b).pow(2.0f64)).sqrt(); + let err = (a - b).abs(); - if err > tolerance || err.is_nan() { + if self.dtype.is_float() { + if let Some((err, tolerance)) = compare_floats(a, b, self.dtype, tolerance) { + // Only print the first 5 different values. + if num_diff < max_num_diff { + message += format!( + "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ + {tolerance}" + ) + .as_str(); + } + num_diff += 1; + } + } else if err > tolerance || err.is_nan() { // Only print the first 5 different values. 
if num_diff < max_num_diff { message += format!( @@ -683,6 +705,7 @@ impl core::fmt::Display for TensorData { DType::I8 => format!("{:?}", self.as_slice::().unwrap()), DType::U64 => format!("{:?}", self.as_slice::().unwrap()), DType::U32 => format!("{:?}", self.as_slice::().unwrap()), + DType::U16 => format!("{:?}", self.as_slice::().unwrap()), DType::U8 => format!("{:?}", self.as_slice::().unwrap()), DType::Bool => format!("{:?}", self.as_slice::().unwrap()), DType::QFloat(q) => match &q { @@ -869,7 +892,7 @@ impl Data { } #[allow(deprecated)] -impl + Clone + core::fmt::Debug + PartialEq, const D: usize> Data { +impl + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data { /// Asserts the data is approximately equal to another data. /// /// # Arguments @@ -926,9 +949,21 @@ impl + Clone + core::fmt::Debug + PartialEq, const D: usize> Data tolerance || err.is_nan() { + if E::dtype().is_float() { + if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) { + // Only print the first 5 different values. + if num_diff < max_num_diff { + message += format!( + "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ + {tolerance}" + ) + .as_str(); + } + num_diff += 1; + } + } else if err > tolerance || err.is_nan() { // Only print the first 5 different values. if num_diff < max_num_diff { message += format!( @@ -1076,6 +1111,30 @@ impl core::fmt::Display for Data { } } +fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> { + let epsilon_deviations = tolerance / f32::EPSILON as f64; + let epsilon = match ty { + DType::F64 => f32::EPSILON as f64, // Don't increase precision beyond `f32`, see below + DType::F32 => f32::EPSILON as f64, + DType::F16 => half::f16::EPSILON.to_f64(), + DType::BF16 => half::bf16::EPSILON.to_f64(), + _ => unreachable!(), + }; + let tolerance_norm = epsilon_deviations * epsilon; + // Clamp to 1.0 so we don't require more precision than `tolerance`. 
This is because literals + // have a fixed number of digits, so increasing precision breaks things + let value_abs = value.abs().max(1.0); + let tolerance_adjusted = tolerance_norm * value_abs; + + let err = (value - other).abs(); + + if err > tolerance_adjusted || err.is_nan() { + Some((err, tolerance_adjusted)) + } else { + None + } +} + #[cfg(test)] #[allow(deprecated)] mod tests { @@ -1140,16 +1199,16 @@ mod tests { #[test] fn should_assert_appox_eq_limit() { let data1 = TensorData::from([[3.0, 5.0, 6.0]]); - let data2 = TensorData::from([[3.01, 5.0, 6.0]]); + let data2 = TensorData::from([[3.03, 5.0, 6.0]]); data1.assert_approx_eq(&data2, 2); } #[test] #[should_panic] - fn should_assert_appox_eq_above_limit() { + fn should_assert_approx_eq_above_limit() { let data1 = TensorData::from([[3.0, 5.0, 6.0]]); - let data2 = TensorData::from([[3.011, 5.0, 6.0]]); + let data2 = TensorData::from([[3.031, 5.0, 6.0]]); data1.assert_approx_eq(&data2, 2); } diff --git a/crates/burn-tensor/src/tensor/element/base.rs b/crates/burn-tensor/src/tensor/element/base.rs index bf08e6ad1..f5721ec46 100644 --- a/crates/burn-tensor/src/tensor/element/base.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -1,6 +1,8 @@ use core::cmp::Ordering; use crate::{cast::ToElement, quantization::QuantizationStrategy, Distribution}; +#[cfg(feature = "cubecl")] +use cubecl::flex32; use half::{bf16, f16}; use rand::RngCore; use serde::{Deserialize, Serialize}; @@ -193,6 +195,14 @@ make_element!( dtype DType::I16 ); +make_element!( + ty u16 Precision::Half, + convert |elem: &dyn ToElement| elem.to_u16(), + random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), + cmp |a: &u16, b: &u16| Ord::cmp(a, b), + dtype DType::U16 +); + make_element!( ty i8 Precision::Other, convert |elem: &dyn ToElement| elem.to_i8(), @@ -230,6 +240,18 @@ make_element!( dtype DType::BF16 ); +#[cfg(feature = "cubecl")] +make_element!( + ty flex32 Precision::Half, + convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()), + random |distribution: Distribution, rng: &mut R| { + let sample: f32 = distribution.sampler(rng).sample(); + flex32::from_elem(sample) + }, + cmp |a: &flex32, b: &flex32| a.total_cmp(b), + dtype DType::F32 +); + make_element!( ty bool Precision::Other, convert |elem: &dyn ToElement| elem.to_u8() != 0, @@ -254,6 +276,7 @@ pub enum DType { I8, U64, U32, + U16, U8, Bool, QFloat(QuantizationStrategy), @@ -273,6 +296,7 @@ impl DType { DType::I8 => core::mem::size_of::(), DType::U64 => core::mem::size_of::(), DType::U32 => core::mem::size_of::(), + DType::U16 => core::mem::size_of::(), DType::U8 => core::mem::size_of::(), DType::Bool => core::mem::size_of::(), DType::QFloat(strategy) => match strategy { diff --git a/crates/burn-tensor/src/tensor/element/cast.rs b/crates/burn-tensor/src/tensor/element/cast.rs index 6cb687c3e..cd26c0ceb 100644 --- a/crates/burn-tensor/src/tensor/element/cast.rs +++ b/crates/burn-tensor/src/tensor/element/cast.rs @@ -442,6 +442,50 @@ impl ToElement for bf16 { } } +#[cfg(feature = "cubecl")] +impl ToElement for cubecl::flex32 { + #[inline] + fn to_i64(&self) -> i64 { + Self::to_f32(*self).to_i64() + } + #[inline] + fn to_u64(&self) -> u64 { + Self::to_f32(*self).to_u64() + } + #[inline] + fn to_i8(&self) -> i8 { + Self::to_f32(*self).to_i8() + } + #[inline] + fn to_u8(&self) -> u8 { + Self::to_f32(*self).to_u8() + } + #[inline] + fn to_i16(&self) -> i16 { + Self::to_f32(*self).to_i16() + } + #[inline] + fn to_u16(&self) -> u16 { + Self::to_f32(*self).to_u16() + } + 
#[inline] + fn to_i32(&self) -> i32 { + Self::to_f32(*self).to_i32() + } + #[inline] + fn to_u32(&self) -> u32 { + Self::to_f32(*self).to_u32() + } + #[inline] + fn to_f32(&self) -> f32 { + Self::to_f32(*self) + } + #[inline] + fn to_f64(&self) -> f64 { + Self::to_f64(*self) + } +} + impl ToElement for bool { #[inline] fn to_i64(&self) -> i64 { diff --git a/crates/burn-tensor/src/tests/activation/hard_sigmoid.rs b/crates/burn-tensor/src/tests/activation/hard_sigmoid.rs index 57ad86c67..917dc2588 100644 --- a/crates/burn-tensor/src/tests/activation/hard_sigmoid.rs +++ b/crates/burn-tensor/src/tests/activation/hard_sigmoid.rs @@ -15,7 +15,7 @@ mod tests { #[test] fn test_hard_sigmoid_overflow() { - let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]); + let tensor = TestTensor::<1>::from([FloatType::MAX, FloatType::MIN]); let output = activation::hard_sigmoid(tensor, 0.2, 0.5); let expected = TensorData::from([1.0, 0.0]); diff --git a/crates/burn-tensor/src/tests/activation/leaky_relu.rs b/crates/burn-tensor/src/tests/activation/leaky_relu.rs index 5f57174ea..bdb71bce9 100644 --- a/crates/burn-tensor/src/tests/activation/leaky_relu.rs +++ b/crates/burn-tensor/src/tests/activation/leaky_relu.rs @@ -9,9 +9,9 @@ mod tests { let output = activation::leaky_relu(tensor, 0.01); - output.into_data().assert_eq( - &TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]), - false, - ); + // Account for conversion errors if `FloatType != f32` + output + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]), 5); } } diff --git a/crates/burn-tensor/src/tests/activation/log_sigmoid.rs b/crates/burn-tensor/src/tests/activation/log_sigmoid.rs index 7cc1ae6e5..87d28c5ac 100644 --- a/crates/burn-tensor/src/tests/activation/log_sigmoid.rs +++ b/crates/burn-tensor/src/tests/activation/log_sigmoid.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(log_sigmoid)] mod tests { use super::*; - use burn_tensor::{activation, Tensor, TensorData}; + use burn_tensor::{activation, cast::ToElement, tests::Numeric, Tensor, TensorData}; #[test] fn test_log_sigmoid() { @@ -10,7 +10,7 @@ mod tests { let output = activation::log_sigmoid(tensor); let expected = TensorData::from([[-3.132617e-1, -9.114665e-4], [-2.260327e-6, -3.0485873]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } #[test] @@ -23,9 +23,10 @@ mod tests { let expected = TensorData::from([0.0, -300.0]); output.into_data().assert_approx_eq(&expected, 4); - let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]); + let tensor = + TestTensor::<1>::from([::MAX, ::MIN]); let output = activation::log_sigmoid(tensor); - let expected = TensorData::from([0.0, f32::MIN]); + let expected = TensorData::from([0.0, ::MIN.to_f32()]); output.into_data().assert_approx_eq(&expected, 4); } diff --git a/crates/burn-tensor/src/tests/activation/mish.rs b/crates/burn-tensor/src/tests/activation/mish.rs index 53feab8bb..2f251bb3f 100644 --- a/crates/burn-tensor/src/tests/activation/mish.rs +++ b/crates/burn-tensor/src/tests/activation/mish.rs @@ -11,6 +11,6 @@ mod tests { let output = activation::mish(tensor); let expected = TensorData::from([[-0.1971, -0.3006, -0.1172], [-0.2413, 0.5823, -0.0888]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } } diff --git a/crates/burn-tensor/src/tests/activation/sigmoid.rs b/crates/burn-tensor/src/tests/activation/sigmoid.rs index 2d42fa4ed..8d4731e2b 100644 --- 
a/crates/burn-tensor/src/tests/activation/sigmoid.rs +++ b/crates/burn-tensor/src/tests/activation/sigmoid.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(sigmoid)] mod tests { use super::*; - use burn_tensor::{activation, Tensor, TensorData}; + use burn_tensor::{activation, tests::Numeric, Tensor, TensorData}; #[test] fn test_sigmoid() { @@ -10,12 +10,12 @@ mod tests { let output = activation::sigmoid(tensor); let expected = TensorData::from([[0.7311, 0.9991], [1.0, 0.0474]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } #[test] fn test_sigmoid_overflow() { - let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]); + let tensor = TestTensor::<1>::from([FloatType::MAX, FloatType::MIN]); let output = activation::sigmoid(tensor); let expected = TensorData::from([1.0, 0.0]); diff --git a/crates/burn-tensor/src/tests/activation/silu.rs b/crates/burn-tensor/src/tests/activation/silu.rs index 097f815f8..1002c592e 100644 --- a/crates/burn-tensor/src/tests/activation/silu.rs +++ b/crates/burn-tensor/src/tests/activation/silu.rs @@ -10,6 +10,6 @@ mod tests { let output = activation::silu(tensor); let expected = TensorData::from([[0.7311, 1.7616], [2.8577, 3.9281]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } } diff --git a/crates/burn-tensor/src/tests/activation/softmax.rs b/crates/burn-tensor/src/tests/activation/softmax.rs index 233dff1d7..8ef5b1a14 100644 --- a/crates/burn-tensor/src/tests/activation/softmax.rs +++ b/crates/burn-tensor/src/tests/activation/softmax.rs @@ -10,6 +10,6 @@ mod tests { let output = activation::softmax(tensor, 1); let expected = TensorData::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } } diff --git a/crates/burn-tensor/src/tests/activation/softmin.rs b/crates/burn-tensor/src/tests/activation/softmin.rs index 62a140188..7ba8b4fc4 100644 --- a/crates/burn-tensor/src/tests/activation/softmin.rs +++ b/crates/burn-tensor/src/tests/activation/softmin.rs @@ -10,6 +10,6 @@ mod tests { let output = activation::softmin(tensor, 1); let expected = TensorData::from([[9.9753e-01, 2.4726e-03], [1.1254e-07, 1.0000e+00]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } } diff --git a/crates/burn-tensor/src/tests/activation/softplus.rs b/crates/burn-tensor/src/tests/activation/softplus.rs index d99768126..8eee383a1 100644 --- a/crates/burn-tensor/src/tests/activation/softplus.rs +++ b/crates/burn-tensor/src/tests/activation/softplus.rs @@ -5,19 +5,17 @@ mod tests { #[test] fn test_softplus_d2() { - let tensor = Tensor::::from([ - [-0.4240, -0.9574, -0.2215], - [-0.5767, 0.7218, -0.1620], - ]); + let tensor = + TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]); let output = activation::softplus(tensor.clone(), 1.0); let expected = TensorData::from([[0.5034, 0.3249, 0.5885], [0.4458, 1.1178, 0.6154]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); let output = activation::softplus(tensor, 2.0); let expected = TensorData::from([[0.1782, 0.0687, 0.2480], [0.1371, 0.8277, 0.2721]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } } diff --git a/crates/burn-tensor/src/tests/activation/tanh_activation.rs 
b/crates/burn-tensor/src/tests/activation/tanh_activation.rs index e76d044e6..49b1a74e6 100644 --- a/crates/burn-tensor/src/tests/activation/tanh_activation.rs +++ b/crates/burn-tensor/src/tests/activation/tanh_activation.rs @@ -10,6 +10,6 @@ mod tests { let output = activation::tanh(tensor); let expected = TensorData::from([[0.7616, 0.9640], [0.9951, 0.9993]]); - output.into_data().assert_approx_eq(&expected, 4); + output.into_data().assert_approx_eq(&expected, 3); } } diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index a4834049d..60f5521ab 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -5,122 +5,59 @@ mod ops; mod quantization; mod stats; +pub use cubecl::prelude::{Float, Int, Numeric}; + #[allow(missing_docs)] #[macro_export] macro_rules! testgen_all { () => { - // test activation - burn_tensor::testgen_gelu!(); - burn_tensor::testgen_mish!(); - burn_tensor::testgen_relu!(); - burn_tensor::testgen_leaky_relu!(); - burn_tensor::testgen_softmax!(); - burn_tensor::testgen_softmin!(); - burn_tensor::testgen_softplus!(); - burn_tensor::testgen_sigmoid!(); - burn_tensor::testgen_log_sigmoid!(); - burn_tensor::testgen_silu!(); - burn_tensor::testgen_tanh_activation!(); + pub mod tensor { + pub use super::*; - // test module - burn_tensor::testgen_module_forward!(); - burn_tensor::testgen_module_conv1d!(); - burn_tensor::testgen_module_conv2d!(); - burn_tensor::testgen_module_conv3d!(); - burn_tensor::testgen_module_deform_conv2d!(); - burn_tensor::testgen_module_conv_transpose1d!(); - burn_tensor::testgen_module_conv_transpose2d!(); - burn_tensor::testgen_module_conv_transpose3d!(); - burn_tensor::testgen_module_unfold4d!(); - burn_tensor::testgen_module_max_pool1d!(); - burn_tensor::testgen_module_max_pool2d!(); - burn_tensor::testgen_module_avg_pool1d!(); - burn_tensor::testgen_module_avg_pool2d!(); - burn_tensor::testgen_module_adaptive_avg_pool1d!(); - burn_tensor::testgen_module_adaptive_avg_pool2d!(); - burn_tensor::testgen_module_nearest_interpolate!(); - burn_tensor::testgen_module_bilinear_interpolate!(); - burn_tensor::testgen_module_bicubic_interpolate!(); + pub type FloatType = ::FloatElem; + pub type IntType = ::IntElem; + pub type BoolType = ::BoolTensorPrimitive; - // test ops - burn_tensor::testgen_add!(); - burn_tensor::testgen_aggregation!(); - burn_tensor::testgen_arange!(); - burn_tensor::testgen_arange_step!(); - burn_tensor::testgen_arg!(); - burn_tensor::testgen_cast!(); - burn_tensor::testgen_cat!(); - burn_tensor::testgen_chunk!(); - burn_tensor::testgen_clamp!(); - burn_tensor::testgen_close!(); - burn_tensor::testgen_cos!(); - burn_tensor::testgen_create_like!(); - burn_tensor::testgen_div!(); - burn_tensor::testgen_erf!(); - burn_tensor::testgen_exp!(); - burn_tensor::testgen_flatten!(); - burn_tensor::testgen_full!(); - burn_tensor::testgen_gather_scatter!(); - burn_tensor::testgen_init!(); - burn_tensor::testgen_iter_dim!(); - burn_tensor::testgen_log!(); - burn_tensor::testgen_log1p!(); - burn_tensor::testgen_map_comparison!(); - burn_tensor::testgen_mask!(); - burn_tensor::testgen_matmul!(); - burn_tensor::testgen_maxmin!(); - burn_tensor::testgen_mul!(); - burn_tensor::testgen_narrow!(); - burn_tensor::testgen_neg!(); - burn_tensor::testgen_one_hot!(); - burn_tensor::testgen_powf_scalar!(); - burn_tensor::testgen_random!(); - burn_tensor::testgen_recip!(); - burn_tensor::testgen_repeat_dim!(); - burn_tensor::testgen_repeat!(); - burn_tensor::testgen_reshape!(); - 
burn_tensor::testgen_select!(); - burn_tensor::testgen_sin!(); - burn_tensor::testgen_slice!(); - burn_tensor::testgen_stack!(); - burn_tensor::testgen_sqrt!(); - burn_tensor::testgen_abs!(); - burn_tensor::testgen_squeeze!(); - burn_tensor::testgen_sub!(); - burn_tensor::testgen_tanh!(); - burn_tensor::testgen_transpose!(); - burn_tensor::testgen_tri!(); - burn_tensor::testgen_powf!(); - burn_tensor::testgen_any!(); - burn_tensor::testgen_all_op!(); - burn_tensor::testgen_permute!(); - burn_tensor::testgen_movedim!(); - burn_tensor::testgen_flip!(); - burn_tensor::testgen_bool!(); - burn_tensor::testgen_argwhere_nonzero!(); - burn_tensor::testgen_sign!(); - burn_tensor::testgen_expand!(); - burn_tensor::testgen_tri_mask!(); - burn_tensor::testgen_sort_argsort!(); - burn_tensor::testgen_topk!(); - burn_tensor::testgen_remainder!(); - burn_tensor::testgen_cartesian_grid!(); - burn_tensor::testgen_nan!(); - burn_tensor::testgen_round!(); - burn_tensor::testgen_floor!(); - burn_tensor::testgen_ceil!(); + $crate::testgen_with_float_param!(); + $crate::testgen_no_param!(); + } + }; + ([$($float:ident),*], [$($int:ident),*]) => { + pub mod tensor { + pub use super::*; - // test stats - burn_tensor::testgen_var!(); - burn_tensor::testgen_cov!(); - burn_tensor::testgen_eye!(); - burn_tensor::testgen_display!(); + pub type FloatType = ::FloatElem; + pub type IntType = ::IntElem; + pub type BoolType = ::BoolTensorPrimitive; - // test clone invariance - burn_tensor::testgen_clone_invariance!(); + ::paste::paste! { + $(mod [<$float _ty>] { + pub use super::*; - // test padding - burn_tensor::testgen_padding!(); + pub type TestBackend = TestBackend2<$float, IntType>; + pub type TestTensor = TestTensor2<$float, IntType, D>; + pub type TestTensorInt = TestTensorInt2<$float, IntType, D>; + pub type TestTensorBool = TestTensorBool2<$float, IntType, D>; + + pub type FloatType = $float; + + $crate::testgen_with_float_param!(); + })* + $(mod [<$int _ty>] { + pub use super::*; + + pub type TestBackend = TestBackend2; + pub type TestTensor = TestTensor2; + pub type TestTensorInt = TestTensorInt2; + pub type TestTensorBool = TestTensorBool2; + + pub type IntType = $int; + + $crate::testgen_with_int_param!(); + })* + } + $crate::testgen_no_param!(); + } }; } @@ -178,3 +115,191 @@ macro_rules! testgen_quantization { burn_tensor::testgen_q_transpose!(); }; } + +#[allow(missing_docs)] +#[macro_export] +macro_rules! 
testgen_with_float_param { + () => { + // test activation + burn_tensor::testgen_gelu!(); + burn_tensor::testgen_mish!(); + burn_tensor::testgen_relu!(); + burn_tensor::testgen_leaky_relu!(); + burn_tensor::testgen_softmax!(); + burn_tensor::testgen_softmin!(); + burn_tensor::testgen_softplus!(); + burn_tensor::testgen_sigmoid!(); + burn_tensor::testgen_log_sigmoid!(); + burn_tensor::testgen_silu!(); + burn_tensor::testgen_tanh_activation!(); + + // test module + burn_tensor::testgen_module_conv1d!(); + burn_tensor::testgen_module_conv2d!(); + burn_tensor::testgen_module_conv3d!(); + burn_tensor::testgen_module_forward!(); + burn_tensor::testgen_module_deform_conv2d!(); + burn_tensor::testgen_module_conv_transpose1d!(); + burn_tensor::testgen_module_conv_transpose2d!(); + burn_tensor::testgen_module_conv_transpose3d!(); + burn_tensor::testgen_module_unfold4d!(); + burn_tensor::testgen_module_max_pool1d!(); + burn_tensor::testgen_module_max_pool2d!(); + burn_tensor::testgen_module_avg_pool1d!(); + burn_tensor::testgen_module_avg_pool2d!(); + burn_tensor::testgen_module_adaptive_avg_pool1d!(); + burn_tensor::testgen_module_adaptive_avg_pool2d!(); + burn_tensor::testgen_module_nearest_interpolate!(); + burn_tensor::testgen_module_bilinear_interpolate!(); + burn_tensor::testgen_module_bicubic_interpolate!(); + + // test ops + burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_narrow!(); + burn_tensor::testgen_add!(); + burn_tensor::testgen_aggregation!(); + burn_tensor::testgen_arange!(); + burn_tensor::testgen_arange_step!(); + burn_tensor::testgen_arg!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_chunk!(); + burn_tensor::testgen_clamp!(); + burn_tensor::testgen_close!(); + burn_tensor::testgen_cos!(); + burn_tensor::testgen_create_like!(); + burn_tensor::testgen_div!(); + burn_tensor::testgen_erf!(); + burn_tensor::testgen_exp!(); + burn_tensor::testgen_flatten!(); + burn_tensor::testgen_full!(); + burn_tensor::testgen_init!(); + burn_tensor::testgen_iter_dim!(); + burn_tensor::testgen_log!(); + burn_tensor::testgen_log1p!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_matmul!(); + burn_tensor::testgen_maxmin!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_neg!(); + burn_tensor::testgen_one_hot!(); + burn_tensor::testgen_powf_scalar!(); + burn_tensor::testgen_random!(); + burn_tensor::testgen_recip!(); + burn_tensor::testgen_repeat_dim!(); + burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_sin!(); + burn_tensor::testgen_slice!(); + burn_tensor::testgen_stack!(); + burn_tensor::testgen_sqrt!(); + burn_tensor::testgen_abs!(); + burn_tensor::testgen_squeeze!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_tanh!(); + burn_tensor::testgen_transpose!(); + burn_tensor::testgen_tri!(); + burn_tensor::testgen_powf!(); + burn_tensor::testgen_any!(); + burn_tensor::testgen_all_op!(); + burn_tensor::testgen_permute!(); + burn_tensor::testgen_movedim!(); + burn_tensor::testgen_flip!(); + burn_tensor::testgen_bool!(); + burn_tensor::testgen_argwhere_nonzero!(); + burn_tensor::testgen_sign!(); + burn_tensor::testgen_expand!(); + burn_tensor::testgen_tri_mask!(); + burn_tensor::testgen_sort_argsort!(); + burn_tensor::testgen_topk!(); + burn_tensor::testgen_remainder!(); + burn_tensor::testgen_cartesian_grid!(); + burn_tensor::testgen_nan!(); + burn_tensor::testgen_round!(); + burn_tensor::testgen_floor!(); + burn_tensor::testgen_ceil!(); + 
burn_tensor::testgen_select!(); + + // test stats + burn_tensor::testgen_var!(); + burn_tensor::testgen_cov!(); + burn_tensor::testgen_eye!(); + + // test padding + burn_tensor::testgen_padding!(); + }; +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_with_int_param { + () => { + // test ops + burn_tensor::testgen_add!(); + burn_tensor::testgen_aggregation!(); + burn_tensor::testgen_arg!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_bool!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_div!(); + burn_tensor::testgen_expand!(); + burn_tensor::testgen_flip!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_movedim!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_permute!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_select!(); + burn_tensor::testgen_sign!(); + burn_tensor::testgen_sort_argsort!(); + burn_tensor::testgen_stack!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_transpose!(); + burn_tensor::testgen_gather_scatter!(); + + // test stats + burn_tensor::testgen_eye!(); + + // test padding + burn_tensor::testgen_padding!(); + }; +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_no_param { + () => { + // test stats + burn_tensor::testgen_display!(); + + // test clone invariance + burn_tensor::testgen_clone_invariance!(); + }; +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! as_bytes { + ($ty:ident: $($elem:expr),*) => { + F::as_bytes(&[$($ty::new($elem),)*]) + }; +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! as_type { + ($ty:ident: [$($elem:tt),*]) => { + [$($crate::as_type![$ty: $elem]),*] + }; + ($ty:ident: [$($elem:tt,)*]) => { + [$($crate::as_type![$ty: $elem]),*] + }; + ($ty:ident: $elem:expr) => { + { + use $crate::tests::{Float, Int}; + + $ty::new($elem) + } + }; +} diff --git a/crates/burn-tensor/src/tests/module/avgpool2d.rs b/crates/burn-tensor/src/tests/module/avgpool2d.rs index d44f5d625..d2d9aa35b 100644 --- a/crates/burn-tensor/src/tests/module/avgpool2d.rs +++ b/crates/burn-tensor/src/tests/module/avgpool2d.rs @@ -106,7 +106,7 @@ mod tests { self.count_include_pad, ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq(&output.into_data(), 2); } } } diff --git a/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs b/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs index fb345a2ed..a875c1ad1 100644 --- a/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs +++ b/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs @@ -111,7 +111,7 @@ mod tests { !output .clone() .to_data() - .as_slice::() + .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), @@ -148,7 +148,7 @@ mod tests { InterpolateOptions::new(InterpolateMode::Bicubic), ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq_diff(&output.into_data(), 0.3); } } } diff --git a/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs b/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs index 1b077ff7c..34cbefb65 100644 --- a/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs +++ b/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs @@ -111,7 +111,7 @@ mod tests { !output .clone() .to_data() - .as_slice::() + .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), @@ -156,7 +156,7 @@ mod tests { InterpolateOptions::new(InterpolateMode::Bilinear), ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq_diff(&output.into_data(), 0.3); } } } 
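The parameterized test macros added in crates/burn-tensor/src/tests/mod.rs above (`testgen_all!`, `testgen_with_float_param!`, `testgen_with_int_param!`, and the `as_type!` helper) let a backend generate the same tensor test suite once per element type. A minimal sketch of how a backend test crate might invoke them follows; the concrete type lists and the `FloatType` usage are illustrative assumptions, not part of this patch:

    // Hypothetical backend test module; the element lists are examples only.
    // Each listed float/int identifier gets its own generated test module
    // (e.g. `f32_ty`, `f16_ty`, `i32_ty`) with the matching Test* aliases.
    burn_jit::testgen_all!([f32, f16], [i32]);

    // Inside a generated test, literals can be converted to the active
    // element type with the exported helper, assuming `FloatType` is the
    // alias defined by the macro for the current module:
    // let values = burn_tensor::as_type!(FloatType: [[1.0, -2.0], [0.5, 3.0]]);
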
diff --git a/crates/burn-tensor/src/tests/module/deform_conv2d.rs b/crates/burn-tensor/src/tests/module/deform_conv2d.rs index ffbfe38c9..04da4704c 100644 --- a/crates/burn-tensor/src/tests/module/deform_conv2d.rs +++ b/crates/burn-tensor/src/tests/module/deform_conv2d.rs @@ -26,7 +26,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [[0.9074, 0.6387], [0.5160, 0.4196]], [[2.4259, 1.8008], [1.5449, 1.3112]], [[3.9444, 2.9629], [2.5738, 2.2027]], @@ -55,7 +55,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([ + test.assert_output(TestTensor::<4>::from([ [ [[0.2155, 0.1928], [0.1934, 0.1755]], [[0.7251, 0.6759], [0.6877, 0.6485]], @@ -93,7 +93,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [[0.1018, 0.0658], [0.0467, 0.0362]], [[0.4125, 0.3367], [0.3069, 0.2824]], [[1.3076, 1.0242], [0.9025, 0.8000]], @@ -123,7 +123,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [[1.0794, 0.7676], [0.7209, 0.5337]], [[2.7059, 2.0216], [1.9740, 1.5419]], [[4.3325, 3.2755], [3.2271, 2.5501]], @@ -153,7 +153,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [[1.0669], [0.6329]], [[2.9741], [2.0383]], [[4.8812], [3.4437]], @@ -180,7 +180,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [ [ 0.1998, 0.3762, 0.5285, 0.6053, 0.3844, 0.1987, 0.0481, 0.0000, @@ -264,7 +264,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [[1.0647], [0.5783]], [[2.9289], [1.8829]], [[4.7931], [3.1875]], @@ -291,7 +291,7 @@ mod tests { width: 5, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [[0.6162], [0.7611], [0.4666]], [[1.8578], [2.2684], [1.6208]], [[3.0994], [3.7757], [2.7749]], @@ -318,7 +318,7 @@ mod tests { width: 4, }; - test.assert_output(Tensor::::from([[ + test.assert_output(TestTensor::<4>::from([[ [ [0.8909, 0.6016], [1.0697, 0.7186], @@ -389,29 +389,29 @@ mod tests { out_width, ]); let device = Default::default(); - let weight = Tensor::::from( + let weight = TestTensor::<4>::from( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<4, _>(shape_weight.clone()) .into_data(), ) .div_scalar(shape_weight.num_elements() as f32); - let bias = Tensor::::from( + let bias = TestTensor::<1>::from( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), ) .div_scalar(self.channels_out as f32); - let x = Tensor::::from( + let x = TestTensor::<4>::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x.clone()) .into_data(), ) .div_scalar(shape_x.num_elements() as f32); - let offset = Tensor::::from( + let offset = TestTensor::<4>::from( TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device) .reshape::<4, _>(shape_offset.clone()) .into_data(), ) .div_scalar(shape_offset.num_elements() as f32); - let mask = Tensor::::from( + let mask = TestTensor::<4>::from( TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device) .reshape::<4, _>(shape_mask.clone()) .into_data(), @@ -433,7 +433,7 @@ mod tests { ), ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq_diff(&output.into_data(), 0.04); } } } diff --git 
a/crates/burn-tensor/src/tests/module/forward.rs b/crates/burn-tensor/src/tests/module/forward.rs index a3207b8fa..36b6da4aa 100644 --- a/crates/burn-tensor/src/tests/module/forward.rs +++ b/crates/burn-tensor/src/tests/module/forward.rs @@ -7,8 +7,8 @@ mod tests { fn test_embedding_forward() { let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TensorData::from([[0, 1], [1, 1]]); - let weights = Tensor::::from(weights); - let indices = Tensor::::from(indices); + let weights = TestTensor::<2>::from(weights); + let indices = TestTensorInt::<2>::from(indices); let output = embedding(weights, indices); let expected = TensorData::from([ diff --git a/crates/burn-tensor/src/tests/module/nearest_interpolate.rs b/crates/burn-tensor/src/tests/module/nearest_interpolate.rs index 72c480a71..1585a90b1 100644 --- a/crates/burn-tensor/src/tests/module/nearest_interpolate.rs +++ b/crates/burn-tensor/src/tests/module/nearest_interpolate.rs @@ -84,7 +84,7 @@ mod tests { !output .clone() .to_data() - .as_slice::() + .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), @@ -122,7 +122,7 @@ mod tests { InterpolateOptions::new(InterpolateMode::Nearest), ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data().assert_approx_eq_diff(&output.into_data(), 0.2); } } } diff --git a/crates/burn-tensor/src/tests/ops/arange.rs b/crates/burn-tensor/src/tests/ops/arange.rs index 5f096762b..f16ff7819 100644 --- a/crates/burn-tensor/src/tests/ops/arange.rs +++ b/crates/burn-tensor/src/tests/ops/arange.rs @@ -8,24 +8,24 @@ mod tests { fn test_arange() { let device = ::Device::default(); - let tensor = Tensor::::arange(2..5, &device); + let tensor = TestTensorInt::<1>::arange(2..5, &device); tensor .into_data() .assert_eq(&TensorData::from([2, 3, 4]), false); // Test arange with negative numbers - let tensor = Tensor::::arange(-10..-5, &device); + let tensor = TestTensorInt::<1>::arange(-10..-5, &device); tensor .into_data() .assert_eq(&TensorData::from([-10, -9, -8, -7, -6]), false); - let tensor = Tensor::::arange(-3..0, &device); + let tensor = TestTensorInt::<1>::arange(-3..0, &device); tensor .into_data() .assert_eq(&TensorData::from([-3, -2, -1]), false); // Test arange with a mix of positive and negative numbers - let tensor = Tensor::::arange(-2..3, &device); + let tensor = TestTensorInt::<1>::arange(-2..3, &device); tensor .clone() .into_data() diff --git a/crates/burn-tensor/src/tests/ops/arange_step.rs b/crates/burn-tensor/src/tests/ops/arange_step.rs index cdb428449..92d656125 100644 --- a/crates/burn-tensor/src/tests/ops/arange_step.rs +++ b/crates/burn-tensor/src/tests/ops/arange_step.rs @@ -9,28 +9,28 @@ mod tests { let device = ::Device::default(); // Test correct sequence of numbers when the range is 0..9 and the step is 1 - let tensor = Tensor::::arange_step(0..9, 1, &device); + let tensor = TestTensorInt::<1>::arange_step(0..9, 1, &device); tensor .into_data() .assert_eq(&TensorData::from([0, 1, 2, 3, 4, 5, 6, 7, 8]), false); // Test correct sequence of numbers when the range is 0..3 and the step is 2 - let tensor = Tensor::::arange_step(0..3, 2, &device); + let tensor = TestTensorInt::<1>::arange_step(0..3, 2, &device); tensor .into_data() .assert_eq(&TensorData::from([0, 2]), false); // Test correct sequence of numbers when the range is 0..2 and the step is 5 - let tensor = Tensor::::arange_step(0..2, 5, &device); + let tensor = TestTensorInt::<1>::arange_step(0..2, 5, &device); tensor.into_data().assert_eq(&TensorData::from([0]), false); // Test correct 
sequence of numbers when the range includes negative numbers - let tensor = Tensor::::arange_step(-3..3, 2, &device); + let tensor = TestTensorInt::<1>::arange_step(-3..3, 2, &device); tensor .into_data() .assert_eq(&TensorData::from([-3, -1, 1]), false); - let tensor = Tensor::::arange_step(-5..1, 5, &device); + let tensor = TestTensorInt::<1>::arange_step(-5..1, 5, &device); tensor .clone() .into_data() @@ -43,6 +43,6 @@ mod tests { fn should_panic_when_step_is_zero() { let device = ::Device::default(); // Test that arange_step panics when the step is 0 - let _tensor = Tensor::::arange_step(0..3, 0, &device); + let _tensor = TestTensorInt::<1>::arange_step(0..3, 0, &device); } } diff --git a/crates/burn-tensor/src/tests/ops/cartesian_grid.rs b/crates/burn-tensor/src/tests/ops/cartesian_grid.rs index dea916ed8..7c4bbb340 100644 --- a/crates/burn-tensor/src/tests/ops/cartesian_grid.rs +++ b/crates/burn-tensor/src/tests/ops/cartesian_grid.rs @@ -9,15 +9,14 @@ mod tests { let device = ::Device::default(); // Test a single element tensor - let tensor: Tensor = - Tensor::::cartesian_grid([1], &device); + let tensor: Tensor = TestTensorInt::<1>::cartesian_grid([1], &device); tensor .into_data() .assert_eq(&TensorData::from([[0]]), false); // Test for a 2x2 tensor let tensor: Tensor = - Tensor::::cartesian_grid([2, 2], &device); + TestTensorInt::<2>::cartesian_grid([2, 2], &device); tensor.into_data().assert_eq( &TensorData::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]), false, diff --git a/crates/burn-tensor/src/tests/ops/cast.rs b/crates/burn-tensor/src/tests/ops/cast.rs index ead9aaa11..14601d428 100644 --- a/crates/burn-tensor/src/tests/ops/cast.rs +++ b/crates/burn-tensor/src/tests/ops/cast.rs @@ -31,9 +31,7 @@ mod tests { #[test] fn cast_bool_to_float_tensor() { - let tensor = - Tensor::::from([[true, false, true], [false, false, true]]) - .float(); + let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).float(); let expected = TensorData::from([[1., 0., 1.], [0., 0., 1.]]); diff --git a/crates/burn-tensor/src/tests/ops/cat.rs b/crates/burn-tensor/src/tests/ops/cat.rs index 945ad93dc..0f404eccd 100644 --- a/crates/burn-tensor/src/tests/ops/cat.rs +++ b/crates/burn-tensor/src/tests/ops/cat.rs @@ -18,8 +18,8 @@ mod tests { #[test] fn should_support_cat_ops_int() { let device = Default::default(); - let tensor_1 = Tensor::::from_data([[1, 2, 3]], &device); - let tensor_2 = Tensor::::from_data([[4, 5, 6]], &device); + let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device); + let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device); let output = Tensor::cat(vec![tensor_1, tensor_2], 0); @@ -31,8 +31,8 @@ mod tests { #[test] fn should_support_cat_ops_bool() { let device = Default::default(); - let tensor_1 = Tensor::::from_data([[false, true, true]], &device); - let tensor_2 = Tensor::::from_data([[true, true, false]], &device); + let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device); + let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device); let output = Tensor::cat(vec![tensor_1, tensor_2], 0); diff --git a/crates/burn-tensor/src/tests/ops/ceil.rs b/crates/burn-tensor/src/tests/ops/ceil.rs index 7d5e53455..af68529b8 100644 --- a/crates/burn-tensor/src/tests/ops/ceil.rs +++ b/crates/burn-tensor/src/tests/ops/ceil.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_ceil_ops() { let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); - let tensor = 
Tensor::<TestBackend, 2>::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.ceil(); let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]); diff --git a/crates/burn-tensor/src/tests/ops/clamp.rs b/crates/burn-tensor/src/tests/ops/clamp.rs index 438ab62af..39bda2a00 100644 --- a/crates/burn-tensor/src/tests/ops/clamp.rs +++ b/crates/burn-tensor/src/tests/ops/clamp.rs @@ -8,7 +8,7 @@ mod tests { let device = Default::default(); // test float tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::<TestBackend, 2>::from_data(data, &device); + let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.clamp_min(2.0); @@ -18,7 +18,7 @@ // test int tensor let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device); + let tensor = TestTensorInt::<2>::from_data(data, &device); let output = tensor.clamp_min(2); output @@ -31,7 +31,7 @@ let device = Default::default(); // test float tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::<TestBackend, 2>::from_data(data, &device); + let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.clamp_max(2.0); @@ -41,7 +41,7 @@ // test int tensor let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device); + let tensor = TestTensorInt::<2>::from_data(data, &device); let output = tensor.clamp_max(4); output @@ -54,7 +54,7 @@ let device = Default::default(); // test float tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::<TestBackend, 2>::from_data(data, &device); + let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.clamp(1.0, 4.0); output @@ -63,7 +63,7 @@ // test int tensor let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); - let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device); + let tensor = TestTensorInt::<2>::from_data(data, &device); let output = tensor.clamp(1, 4); output diff --git a/crates/burn-tensor/src/tests/ops/cos.rs b/crates/burn-tensor/src/tests/ops/cos.rs index f26ca88f3..7872ca966 100644 --- a/crates/burn-tensor/src/tests/ops/cos.rs +++ b/crates/burn-tensor/src/tests/ops/cos.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_cos_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cos(); let expected = TensorData::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); diff --git a/crates/burn-tensor/src/tests/ops/div.rs b/crates/burn-tensor/src/tests/ops/div.rs index 2bce2cd4d..0c5ebe5aa 100644 --- a/crates/burn-tensor/src/tests/ops/div.rs +++ b/crates/burn-tensor/src/tests/ops/div.rs @@ -8,8 +8,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); - let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device); - let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); @@ -22,8 +22,8 @@ let data_1 = TensorData::from([[0.0, 1.0, 2.0]]); let data_2 =
TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; @@ -38,7 +38,7 @@ mod tests { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor / scalar; @@ -52,8 +52,8 @@ mod tests { let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let data_2 = TensorData::from([[1, 1, 2], [1, 1, 2]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; @@ -67,8 +67,8 @@ mod tests { let data_1 = TensorData::from([[0, 1, 2]]); let data_2 = TensorData::from([[1, 1, 2], [3, 4, 5]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; @@ -81,7 +81,7 @@ mod tests { fn should_support_div_scalar_ops_int() { let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let scalar = 2; - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor / scalar; diff --git a/crates/burn-tensor/src/tests/ops/erf.rs b/crates/burn-tensor/src/tests/ops/erf.rs index 1219b5210..c768e648c 100644 --- a/crates/burn-tensor/src/tests/ops/erf.rs +++ b/crates/burn-tensor/src/tests/ops/erf.rs @@ -6,18 +6,18 @@ mod tests { #[test] fn should_support_erf_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.erf(); let expected = TensorData::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 2); } #[test] fn should_support_erf_ops_with_negative_number() { let data = TensorData::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.erf(); let expected = TensorData::from([ @@ -25,6 +25,6 @@ mod tests { [1.0000, 1.0000, 1.0000], ]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 2); } } diff --git a/crates/burn-tensor/src/tests/ops/exp.rs b/crates/burn-tensor/src/tests/ops/exp.rs index 9c3458346..cb3aaa793 100644 --- a/crates/burn-tensor/src/tests/ops/exp.rs +++ b/crates/burn-tensor/src/tests/ops/exp.rs @@ -6,11 +6,11 @@ mod tests { #[test] fn should_support_exp_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.exp(); let 
expected = TensorData::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 2); } } diff --git a/crates/burn-tensor/src/tests/ops/expand.rs b/crates/burn-tensor/src/tests/ops/expand.rs index 3ac85d17e..a034b2542 100644 --- a/crates/burn-tensor/src/tests/ops/expand.rs +++ b/crates/burn-tensor/src/tests/ops/expand.rs @@ -5,7 +5,7 @@ mod tests { #[test] fn expand_2d() { - let tensor = Tensor::::from_floats([1.0, 2.0, 3.0], &Default::default()); + let tensor = TestTensor::<1>::from_floats([1.0, 2.0, 3.0], &Default::default()); let output = tensor.expand([3, 3]); output.into_data().assert_eq( @@ -13,8 +13,7 @@ mod tests { false, ); - let tensor = - Tensor::::from_floats([4.0, 7.0, 2.0, 3.0], &Default::default()); + let tensor = TestTensor::<1>::from_floats([4.0, 7.0, 2.0, 3.0], &Default::default()); let output = tensor.expand([2, 4]); output.into_data().assert_eq( @@ -25,8 +24,7 @@ mod tests { #[test] fn expand_3d() { - let tensor = - Tensor::::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default()); + let tensor = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default()); let output = tensor.expand([3, 2, 2]); let expected = TensorData::from([ [[1.0, 2.0], [3.0, 4.0]], @@ -39,8 +37,7 @@ mod tests { #[test] fn expand_higher_dimensions() { - let tensor = - Tensor::::from_floats([[1.0, 2.0, 3.0, 4.0]], &Default::default()); + let tensor = TestTensor::<2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &Default::default()); let output = tensor.expand([2, 3, 4]); let expected = TensorData::from([ [ @@ -60,7 +57,7 @@ mod tests { #[test] fn broadcast_single() { - let tensor = Tensor::::from_floats([1.0], &Default::default()); + let tensor = TestTensor::<1>::from_floats([1.0], &Default::default()); let output = tensor.expand([2, 3]); output @@ -71,7 +68,7 @@ mod tests { #[test] #[should_panic] fn should_fail_expand_incompatible_shapes() { - let tensor = Tensor::::from_floats([1.0, 2.0, 3.0], &Default::default()); + let tensor = TestTensor::<1>::from_floats([1.0, 2.0, 3.0], &Default::default()); let _expanded_tensor = tensor.expand([2, 2]); } diff --git a/crates/burn-tensor/src/tests/ops/flatten.rs b/crates/burn-tensor/src/tests/ops/flatten.rs index 53cb15c0c..96547ebe8 100644 --- a/crates/burn-tensor/src/tests/ops/flatten.rs +++ b/crates/burn-tensor/src/tests/ops/flatten.rs @@ -6,7 +6,7 @@ mod tests { /// Test if the function can successfully flatten a 4D tensor to a 1D tensor. #[test] fn should_flatten_to_1d() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: Tensor = tensor.flatten(0, 3); let expected_shape = Shape::new([120]); assert_eq!(flattened_tensor.shape(), expected_shape); @@ -15,7 +15,7 @@ mod tests { /// Test if the function can successfully flatten the middle dimensions of a 4D tensor. #[test] fn should_flatten_middle() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: Tensor = tensor.flatten(1, 2); let expected_shape = Shape::new([2, 12, 5]); assert_eq!(flattened_tensor.shape(), expected_shape); @@ -24,7 +24,7 @@ mod tests { /// Test if the function can successfully flatten the first dimensions of a 4D tensor. 
#[test] fn should_flatten_begin() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: Tensor = tensor.flatten(0, 2); let expected_shape = Shape::new([24, 5]); assert_eq!(flattened_tensor.shape(), expected_shape); @@ -33,7 +33,7 @@ mod tests { /// Test if the function can successfully flatten the last dimensions of a 4D tensor. #[test] fn should_flatten_end() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: Tensor = tensor.flatten(1, 3); let expected_shape = Shape::new([2, 60]); assert_eq!(flattened_tensor.shape(), expected_shape); @@ -43,14 +43,14 @@ mod tests { #[test] #[should_panic] fn should_flatten_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: Tensor = tensor.flatten(2, 0); } #[test] #[should_panic] fn not_enough_destination_dimension() { - let tensor = Tensor::::ones(Shape::new([1, 5, 15]), &Default::default()); + let tensor = TestTensor::<3>::ones(Shape::new([1, 5, 15]), &Default::default()); let flattened_tensor: Tensor = tensor.flatten(1, 2); let expected_shape = Shape::new([75]); assert_eq!(flattened_tensor.shape(), expected_shape); diff --git a/crates/burn-tensor/src/tests/ops/flip.rs b/crates/burn-tensor/src/tests/ops/flip.rs index fd414a850..a60ec56cb 100644 --- a/crates/burn-tensor/src/tests/ops/flip.rs +++ b/crates/burn-tensor/src/tests/ops/flip.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn flip_int() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let flipped = tensor.clone().flip([0, 2]); // from pytorch: @@ -26,7 +26,7 @@ mod tests { #[test] fn flip_float() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); @@ -52,7 +52,7 @@ mod tests { #[test] fn flip_bool() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); @@ -84,7 +84,7 @@ mod tests { #[should_panic] fn flip_duplicated_axes() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a duplicated axis let _ = tensor.clone().flip([0, 0, 1]); @@ -94,7 +94,7 @@ mod tests { #[should_panic] fn flip_out_of_bound_axis() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with an out of bound axis let _ = tensor.clone().flip([3, 0, 1]); diff --git a/crates/burn-tensor/src/tests/ops/floor.rs b/crates/burn-tensor/src/tests/ops/floor.rs index 5913244e6..e681ca296 100644 --- a/crates/burn-tensor/src/tests/ops/floor.rs +++ b/crates/burn-tensor/src/tests/ops/floor.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_floor_ops() { let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); - let tensor = Tensor::::from_data(data, 
&Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.floor(); let expected = TensorData::from([[24., 87., 76.], [59., 43., 94.]]); diff --git a/crates/burn-tensor/src/tests/ops/full.rs b/crates/burn-tensor/src/tests/ops/full.rs index c8ac52e7e..eb977c2e2 100644 --- a/crates/burn-tensor/src/tests/ops/full.rs +++ b/crates/burn-tensor/src/tests/ops/full.rs @@ -14,14 +14,14 @@ mod tests { fn test_tensor_full() { let device = Default::default(); // Test full with f32 - let tensor = Tensor::<TestBackend, 2>::full([2, 3], 2.1, &device); + let tensor = TestTensor::<2>::full([2, 3], 2.1, &device); tensor .into_data() .assert_eq(&TensorData::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]), false); // Test full with Int - let int_tensor = Tensor::<TestBackend, 2, Int>::full([2, 2], 2, &device); + let int_tensor = TestTensorInt::<2>::full([2, 2], 2, &device); int_tensor .into_data() @@ -29,11 +29,11 @@ // TODO enable after adding support for bool // // Test full with bool - // let bool_tensor = Tensor::<TestBackend, 2, Bool>::full([2, 2], true, &device); + // let bool_tensor = TestTensorBool::<2>::full([2, 2], true, &device); // let data_expected = TensorData::from([[true, true], [true, true]]); // assert_eq!(data_expected, bool_tensor.into_data()); - // let bool_tensor = Tensor::<TestBackend, 2, Bool>::full([2, 2], false, &device); + // let bool_tensor = TestTensorBool::<2>::full([2, 2], false, &device); // let data_expected = TensorData::from([[false, false], [false, false]]); // assert_eq!(data_expected, bool_tensor.into_data()); } diff --git a/crates/burn-tensor/src/tests/ops/init.rs b/crates/burn-tensor/src/tests/ops/init.rs index b0bb019a7..8d862b5de 100644 --- a/crates/burn-tensor/src/tests/ops/init.rs +++ b/crates/burn-tensor/src/tests/ops/init.rs @@ -6,21 +6,21 @@ mod tests { #[test] fn should_support_float_empty() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2>::empty(shape, &Default::default()); + let tensor = TestTensor::<2>::empty(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()) } #[test] fn should_support_int_empty() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2, Int>::empty(shape, &Default::default()); + let tensor = TestTensorInt::<2>::empty(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()) } #[test] fn should_support_float_zeros() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2>::zeros(shape, &Default::default()); + let tensor = TestTensor::<2>::zeros(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor @@ -31,7 +31,7 @@ #[test] fn should_support_int_zeros() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2, Int>::zeros(shape, &Default::default()); + let tensor = TestTensorInt::<2>::zeros(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor @@ -42,7 +42,7 @@ #[test] fn should_support_float_ones() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2>::ones(shape, &Default::default()); + let tensor = TestTensor::<2>::ones(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor @@ -53,7 +53,7 @@ #[test] fn should_support_int_ones() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2, Int>::ones(shape, &Default::default()); + let tensor = TestTensorInt::<2>::ones(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor @@ -64,7 +64,7 @@ #[test] fn should_support_bool_empty() { let shape = [2, 2]; - let tensor = Tensor::<TestBackend, 2, Bool>::empty(shape, &Default::default()); + let tensor = TestTensorBool::<2>::empty(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()) }
} diff --git a/crates/burn-tensor/src/tests/ops/iter_dim.rs b/crates/burn-tensor/src/tests/ops/iter_dim.rs index d3215d1a7..6514f3fe9 100644 --- a/crates/burn-tensor/src/tests/ops/iter_dim.rs +++ b/crates/burn-tensor/src/tests/ops/iter_dim.rs @@ -7,9 +7,9 @@ mod test { fn test_1d_iter_last_item() { let data = [1, 2, 3, 4]; let device = Default::default(); - let tensor = Tensor::::from_ints(data, &device); + let tensor = TestTensorInt::<1>::from_ints(data, &device); assert_eq!( - Tensor::::from_ints([4], &device).into_data(), + TestTensorInt::<1>::from_ints([4], &device).into_data(), tensor.iter_dim(0).last().unwrap().into_data() ) } @@ -17,7 +17,7 @@ mod test { #[test] #[should_panic] fn test_too_high_dimension() { - Tensor::::zeros([10], &Default::default()).iter_dim(1); + TestTensor::<1>::zeros([10], &Default::default()).iter_dim(1); } #[test] @@ -27,7 +27,7 @@ mod test { [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]; - let tensor = Tensor::::from_floats(data, &Default::default()); + let tensor = TestTensor::<2>::from_floats(data, &Default::default()); let lhs = tensor.clone().slice([1..2, 0..5]); let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap(); assert_eq!( @@ -46,7 +46,7 @@ mod test { [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]; 5]; - let tensor = Tensor::::from_floats(data, &Default::default()); + let tensor = TestTensor::<3>::from_floats(data, &Default::default()); let lhs = tensor.iter_dim(2).nth(1).unwrap(); let rhs = TensorData::from([2., 5., 8.]); assert_eq!( @@ -60,8 +60,8 @@ mod test { #[test] fn test_iter_dim_double_end() { - let input = Tensor::::arange(0..(4 * 6 * 3), &Default::default()) - .reshape([4, 6, 3]); + let input = + TestTensorInt::<1>::arange(0..(4 * 6 * 3), &Default::default()).reshape([4, 6, 3]); let mut iter = input.iter_dim(1); let ele0 = TensorData::from([[[0, 1, 2]], [[18, 19, 20]], [[36, 37, 38]], [[54, 55, 56]]]); @@ -104,8 +104,8 @@ mod test { #[test] fn test_iter_dim_single_element() { - let input = Tensor::::arange(0..(4 * 1 * 3), &Default::default()) - .reshape([4, 1, 3]); + let input = + TestTensorInt::<1>::arange(0..(4 * 1 * 3), &Default::default()).reshape([4, 1, 3]); let mut iter = input.clone().iter_dim(1); iter.next() diff --git a/crates/burn-tensor/src/tests/ops/log.rs b/crates/burn-tensor/src/tests/ops/log.rs index 5d6c59c24..28477f513 100644 --- a/crates/burn-tensor/src/tests/ops/log.rs +++ b/crates/burn-tensor/src/tests/ops/log.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_log_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.log(); let expected = TensorData::from([ @@ -14,6 +14,6 @@ mod tests { [1.0986, 1.3862, 1.6094], ]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 2); } } diff --git a/crates/burn-tensor/src/tests/ops/log1p.rs b/crates/burn-tensor/src/tests/ops/log1p.rs index 008fb97dc..b042f7f63 100644 --- a/crates/burn-tensor/src/tests/ops/log1p.rs +++ b/crates/burn-tensor/src/tests/ops/log1p.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_exp_log1p() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.log1p(); let expected = TensorData::from([ @@ -14,6 +14,6 @@ mod tests { [1.3862, 1.6094, 1.7917], ]); - 
output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 2); } } diff --git a/crates/burn-tensor/src/tests/ops/map_comparison.rs b/crates/burn-tensor/src/tests/ops/map_comparison.rs index def083f6b..f6e5eef58 100644 --- a/crates/burn-tensor/src/tests/ops/map_comparison.rs +++ b/crates/burn-tensor/src/tests/ops/map_comparison.rs @@ -133,8 +133,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [f32::INFINITY, 4.0, f32::NEG_INFINITY]]); let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); @@ -149,8 +149,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, f32::INFINITY, 5.0]]); let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); @@ -375,8 +375,8 @@ mod tests { let data_1 = TensorData::from([[false, true, true], [true, false, true]]); let data_2 = TensorData::from([[false, false, true], [false, true, true]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); @@ -391,8 +391,8 @@ mod tests { let data_1 = TensorData::from([[false, true, true], [true, false, true]]); let data_2 = TensorData::from([[false, false, true], [false, true, true]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); @@ -405,7 +405,7 @@ mod tests { #[test] fn should_support_bool_not() { let data_1 = TensorData::from([[false, true, true], [true, true, false]]); - let tensor_1 = Tensor::::from_data(data_1, &Default::default()); + let tensor_1 = TestTensorBool::<2>::from_data(data_1, &Default::default()); let data_actual_cloned = tensor_1.clone().bool_not(); let data_actual_inplace = tensor_1.bool_not(); diff --git a/crates/burn-tensor/src/tests/ops/mask.rs b/crates/burn-tensor/src/tests/ops/mask.rs index 64d88c617..80fd46e17 100644 --- a/crates/burn-tensor/src/tests/ops/mask.rs +++ b/crates/burn-tensor/src/tests/ops/mask.rs @@ -7,14 +7,11 @@ mod tests { fn should_support_mask_where_ops() { let device = Default::default(); let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device); - let mask = Tensor::::from_bool( + 
let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); - let value = Tensor::::from_data( - TensorData::from([[1.8, 2.8], [3.8, 4.8]]), - &device, - ); + let value = TestTensor::<2>::from_data(TensorData::from([[1.8, 2.8], [3.8, 4.8]]), &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([[1.8, 7.0], [2.0, 4.8]]); @@ -26,8 +23,8 @@ mod tests { fn should_support_mask_where_broadcast_int() { let device = Default::default(); // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times - let tensor = Tensor::::arange(2..6, &device).reshape([1, 2, 2]); - let mask = Tensor::::from_bool( + let tensor = TestTensorInt::<1>::arange(2..6, &device).reshape([1, 2, 2]); + let mask = TestTensorBool::<3>::from_bool( TensorData::from([ [[true, false], [false, true]], [[false, true], [true, false]], @@ -36,7 +33,7 @@ mod tests { ]), &device, ); - let value = Tensor::::ones([4, 2, 2], &device); + let value = TestTensorInt::<3>::ones([4, 2, 2], &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([ @@ -53,8 +50,8 @@ mod tests { fn should_support_mask_where_broadcast() { let device = Default::default(); // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times - let tensor = Tensor::::arange(2..6, &device).reshape([1, 2, 2]); - let mask = Tensor::::from_bool( + let tensor = TestTensorInt::<1>::arange(2..6, &device).reshape([1, 2, 2]); + let mask = TestTensorBool::<3>::from_bool( TensorData::from([ [[true, false], [false, true]], [[false, true], [true, false]], @@ -63,7 +60,7 @@ mod tests { ]), &device, ); - let value = Tensor::::ones([4, 2, 2], &device); + let value = TestTensor::<3>::ones([4, 2, 2], &device); let output = tensor.float().mask_where(mask, value); let expected = TensorData::from([ @@ -87,7 +84,7 @@ mod tests { ], &device, ); - let mask = Tensor::::from_bool( + let mask = TestTensorBool::<2>::from_bool( TensorData::from([ [true, true, true], [true, true, false], @@ -95,7 +92,7 @@ mod tests { ]), &device, ); - let value = Tensor::::from_data( + let value = TestTensor::<2>::from_data( TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]), &device, ); @@ -107,14 +104,14 @@ mod tests { [f32::NAN, f32::NAN, f32::NAN], ]); - output.into_data().assert_eq(&expected, false); + output.into_data().assert_approx_eq(&expected, 5); } #[test] fn should_support_mask_fill_ops() { let device = Default::default(); let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device); - let mask = Tensor::::from_bool( + let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); @@ -128,13 +125,12 @@ mod tests { #[test] fn should_support_int_mask_where_ops() { let device = Default::default(); - let tensor = Tensor::::from_data([[1, 7], [2, 3]], &device); - let mask = Tensor::::from_bool( + let tensor = TestTensorInt::<2>::from_data([[1, 7], [2, 3]], &device); + let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); - let value = - Tensor::::from_data(TensorData::from([[8, 9], [10, 11]]), &device); + let value = TestTensorInt::<2>::from_data(TensorData::from([[8, 9], [10, 11]]), &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([[8, 7], [2, 11]]); @@ -145,8 +141,8 @@ mod tests { #[test] fn should_support_int_mask_fill_ops() { let device = Default::default(); - let tensor = Tensor::::from_data([[1, 7], [2, 3]], &device); - let mask = 
Tensor::::from_bool( + let tensor = TestTensorInt::<2>::from_data([[1, 7], [2, 3]], &device); + let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); @@ -160,14 +156,14 @@ mod tests { #[test] fn float_mask_fill_infinite() { let device = Default::default(); - let tensor = Tensor::::from_data( + let tensor = TestTensor::<2>::from_data( [ [f32::NEG_INFINITY, f32::NEG_INFINITY], [f32::NEG_INFINITY, f32::NEG_INFINITY], ], &device, ); - let mask = Tensor::::from_bool( + let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); diff --git a/crates/burn-tensor/src/tests/ops/matmul.rs b/crates/burn-tensor/src/tests/ops/matmul.rs index b897def7a..0e666c864 100644 --- a/crates/burn-tensor/src/tests/ops/matmul.rs +++ b/crates/burn-tensor/src/tests/ops/matmul.rs @@ -102,20 +102,20 @@ mod tests { fn test_matmul_trivial() { let device = Default::default(); - let tensor_1 = Tensor::::arange(0..16, &device) + let tensor_1 = TestTensorInt::<1>::arange(0..16, &device) .reshape([4, 4]) .float(); let tensor_3 = tensor_1.clone().matmul(tensor_1); - tensor_3.into_data().assert_eq( + tensor_3.into_data().assert_approx_eq( &TensorData::from([ [56., 62., 68., 74.], [152., 174., 196., 218.], [248., 286., 324., 362.], [344., 398., 452., 506.], ]), - false, + 3, ); } @@ -123,20 +123,20 @@ mod tests { fn test_matmul_trivial_transposed() { let device = Default::default(); - let tensor_1 = Tensor::::arange(0..16, &device) + let tensor_1 = TestTensorInt::<1>::arange(0..16, &device) .reshape([4, 4]) .float(); let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); - tensor_3.into_data().assert_eq( + tensor_3.into_data().assert_approx_eq( &TensorData::from([ [14., 38., 62., 86.], [38., 126., 214., 302.], [62., 214., 366., 518.], [86., 302., 518., 734.], ]), - false, + 3, ); } @@ -144,20 +144,20 @@ mod tests { fn test_matmul_4_8() { let device = Default::default(); - let tensor_1 = Tensor::::arange(0..32, &device) + let tensor_1 = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); - tensor_3.into_data().assert_eq( + tensor_3.into_data().assert_approx_eq( &TensorData::from([ [140., 364., 588., 812.], [364., 1100., 1836., 2572.], [588., 1836., 3084., 4332.], [812., 2572., 4332., 6092.], ]), - false, + 4, ); } diff --git a/crates/burn-tensor/src/tests/ops/movedim.rs b/crates/burn-tensor/src/tests/ops/movedim.rs index 35aa15ebb..24855926e 100644 --- a/crates/burn-tensor/src/tests/ops/movedim.rs +++ b/crates/burn-tensor/src/tests/ops/movedim.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn movedim_int() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let permuted = tensor.clone().movedim(0, 2); // from pytorch: @@ -31,7 +31,7 @@ mod tests { #[test] fn movedim_float() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); @@ -58,7 +58,7 @@ mod tests { #[test] fn movedim_bool() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); @@ -85,7 +85,7 @@ mod tests { #[test] fn vec_input_int() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + 
let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]); // from pytorch @@ -110,7 +110,7 @@ mod tests { #[test] fn vec_input_float() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); @@ -137,7 +137,7 @@ mod tests { #[test] fn vec_input_bool() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); @@ -164,7 +164,7 @@ mod tests { #[test] fn different_input_types() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); @@ -192,7 +192,7 @@ mod tests { #[should_panic] fn edge_different_sizes() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().movedim(vec![0, 1], vec![0]); @@ -202,7 +202,7 @@ mod tests { #[should_panic] fn edge_out_of_bound_axis() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with an out of bound axis let _ = tensor.clone().movedim(0, 100); @@ -212,7 +212,7 @@ mod tests { #[should_panic] fn edge_vec_is_not_a_set() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().movedim(vec![0, 1, 1, 1, 1], vec![0, 0, 1]); @@ -222,7 +222,7 @@ mod tests { #[should_panic] fn edge_out_of_bound_axis_vec() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with an out of bound axis let _ = tensor.clone().movedim(vec![0, 100], vec![0, 1]); diff --git a/crates/burn-tensor/src/tests/ops/mul.rs b/crates/burn-tensor/src/tests/ops/mul.rs index adeaa68f5..1f18f251d 100644 --- a/crates/burn-tensor/src/tests/ops/mul.rs +++ b/crates/burn-tensor/src/tests/ops/mul.rs @@ -8,8 +8,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_2 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); @@ -22,8 +22,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0]]); let data_2 = TensorData::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); @@ -34,10 +34,8 @@ mod tests { #[test] fn 
test_mul_broadcast_2_dims() { let device = Default::default(); - let tensor_1 = - Tensor::::from_data([0.0, 1.0, 2.0], &device).reshape([3, 1]); - let tensor_2 = - Tensor::::from_data([3.0, 4.0, 5.0], &device).reshape([1, 3]); + let tensor_1 = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device).reshape([3, 1]); + let tensor_2 = TestTensor::<1>::from_data([3.0, 4.0, 5.0], &device).reshape([1, 3]); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]); @@ -49,7 +47,7 @@ mod tests { fn should_support_mul_scalar_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor * scalar; let expected = TensorData::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); @@ -62,8 +60,8 @@ mod tests { let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let data_2 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0, 1, 4], [9, 16, 25]]); @@ -76,8 +74,8 @@ mod tests { let data_1 = TensorData::from([[0, 1, 2]]); let data_2 = TensorData::from([[3, 4, 5], [6, 7, 8]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0, 4, 10], [0, 7, 16]]); @@ -89,7 +87,7 @@ mod tests { fn should_support_mul_scalar_ops_int() { let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let scalar = 2; - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor * scalar; let expected = TensorData::from([[0, 2, 4], [6, 8, 10]]); diff --git a/crates/burn-tensor/src/tests/ops/neg.rs b/crates/burn-tensor/src/tests/ops/neg.rs index 3be7e78ad..87725ada9 100644 --- a/crates/burn-tensor/src/tests/ops/neg.rs +++ b/crates/burn-tensor/src/tests/ops/neg.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_neg_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.neg(); let expected = TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::(); diff --git a/crates/burn-tensor/src/tests/ops/padding.rs b/crates/burn-tensor/src/tests/ops/padding.rs index 19b55fd92..2b551658c 100644 --- a/crates/burn-tensor/src/tests/ops/padding.rs +++ b/crates/burn-tensor/src/tests/ops/padding.rs @@ -1,23 +1,28 @@ #[burn_tensor_testgen::testgen(padding)] mod tests { use super::*; - use burn_tensor::{backend::Backend, Int, Numeric, Shape, Tensor, TensorData}; + use burn_tensor::{ + as_type, + backend::Backend, + tests::{Float as _, Int as _}, + Numeric, Shape, Tensor, TensorData, + }; #[test] fn padding_2d_test() { let unpadded_floats: [[f32; 3]; 2] = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]; let tensor = 
TestTensor::<2>::from(unpadded_floats); - let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1); + let padded_tensor = tensor.pad((2, 2, 2, 2), FloatType::new(1.1)); - let expected = TensorData::from([ + let expected = TensorData::from(as_type!(FloatType: [ [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1], [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], - ]); + ])); padded_tensor.into_data().assert_eq(&expected, false); } @@ -26,9 +31,9 @@ mod tests { let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]]; let tensor = TestTensor::<4>::from(unpadded_floats); - let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1); + let padded_tensor = tensor.pad((2, 2, 2, 2), FloatType::new(1.1)); - let expected = TensorData::from([[[ + let expected = TensorData::from(as_type!(FloatType: [[[ [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 0.0, 1.0, 1.1, 1.1], @@ -36,7 +41,7 @@ mod tests { [1.1, 1.1, 4.0, 5.0, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], - ]]]); + ]]])); padded_tensor.into_data().assert_eq(&expected, false); } @@ -45,9 +50,9 @@ mod tests { let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]]; let tensor = TestTensor::<4>::from(unpadded_floats); - let padded_tensor = tensor.pad((2, 1, 4, 3), 1.1); + let padded_tensor = tensor.pad((2, 1, 4, 3), FloatType::new(1.1)); - let expected = TensorData::from([[[ + let expected = TensorData::from(as_type!(FloatType: [[[ [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], @@ -58,7 +63,7 @@ mod tests { [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], - ]]]); + ]]])); padded_tensor.into_data().assert_eq(&expected, false); } @@ -67,7 +72,7 @@ mod tests { let unpadded_ints = [[[[0, 1], [2, 3], [4, 5]]]]; let tensor = TestTensorInt::<4>::from(unpadded_ints); - let padded_tensor = tensor.pad((2, 1, 4, 3), 6); + let padded_tensor = tensor.pad((2, 1, 4, 3), IntType::new(6)); let padded_primitive_data_expected = [[[ [6, 6, 6, 6, 6], @@ -81,7 +86,7 @@ mod tests { [6, 6, 6, 6, 6], [6, 6, 6, 6, 6], ]]]; - let expected = TensorData::from([[[ + let expected = TensorData::from(as_type!(IntType: [[[ [6, 6, 6, 6, 6], [6, 6, 6, 6, 6], [6, 6, 6, 6, 6], @@ -92,7 +97,7 @@ mod tests { [6, 6, 6, 6, 6], [6, 6, 6, 6, 6], [6, 6, 6, 6, 6], - ]]]); + ]]])); padded_tensor.into_data().assert_eq(&expected, false); } } diff --git a/crates/burn-tensor/src/tests/ops/permute.rs b/crates/burn-tensor/src/tests/ops/permute.rs index 898a38d8c..5f5384ca2 100644 --- a/crates/burn-tensor/src/tests/ops/permute.rs +++ b/crates/burn-tensor/src/tests/ops/permute.rs @@ -7,7 +7,7 @@ mod tests { #[test] fn permute_int() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let permuted = tensor.clone().permute([2, 1, 0]); @@ -34,7 +34,7 @@ mod tests { #[test] fn permute_float() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); @@ -63,7 +63,7 @@ mod tests { #[test] fn permute_bool() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device) + let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); @@ -93,7 +93,7 @@ mod tests { #[should_panic] fn 
edge_repeated_axes() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().permute([0, 0, 1]); @@ -103,7 +103,7 @@ mod tests { #[should_panic] fn edge_out_of_bound_axis() { let device = Default::default(); - let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().permute([3, 0, 1]); diff --git a/crates/burn-tensor/src/tests/ops/powf.rs b/crates/burn-tensor/src/tests/ops/powf.rs index 4997b1de5..4f1d537e0 100644 --- a/crates/burn-tensor/src/tests/ops/powf.rs +++ b/crates/burn-tensor/src/tests/ops/powf.rs @@ -6,22 +6,22 @@ mod tests { #[test] fn should_support_powf_ops() { let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]); - let tensor_pow = Tensor::::from_data(pow, &Default::default()); + let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq_diff(&expected, 0.1); } #[test] fn should_support_neg_power() { let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[-0.95, -0.67, -0.45], [-0.24, -0.5, -0.6]]); - let tensor_pow = Tensor::::from_data(pow, &Default::default()); + let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1., 1., 0.73204285], [0.76822936, 0.5, 0.38073079]]); @@ -32,43 +32,45 @@ mod tests { #[test] fn should_support_neg_values_with_even_power() { let data = TensorData::from([[1.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[2.0, 2.0, 4.0], [4.0, 4.0, 2.0]]); - let tensor_pow = Tensor::::from_data(pow, &Default::default()); + let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, 1.0, 16.0], [81.0, 256.0, 25.0]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 1); } #[test] fn should_support_neg_values_with_odd_power() { let data = TensorData::from([[1.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]); - let tensor_pow = Tensor::::from_data(pow, &Default::default()); + let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq_diff(&expected, 0.5); } #[test] fn should_support_powf_broadcasted() 
{ let device = Default::default(); - let tensor_1 = Tensor::::from_floats([2.0, 3.0, 4.0], &device); + let tensor_1 = TestTensor::<1>::from_floats([2.0, 3.0, 4.0], &device); let tensor_2 = Tensor::from_floats([1.0], &device); // Broadcast rhs let output = tensor_1.clone().powf(tensor_2.clone()); - output.into_data().assert_approx_eq(&tensor_1.to_data(), 3); + output + .into_data() + .assert_approx_eq_diff(&tensor_1.to_data(), 0.004); // Broadcast lhs let output = tensor_2.powf(tensor_1); output .into_data() - .assert_approx_eq(&TensorData::from([1.0, 1.0, 1.0]), 3); + .assert_approx_eq_diff(&TensorData::from([1.0, 1.0, 1.0]), 0.004); } } diff --git a/crates/burn-tensor/src/tests/ops/powf_scalar.rs b/crates/burn-tensor/src/tests/ops/powf_scalar.rs index cd1db0ac0..6a5f8bf09 100644 --- a/crates/burn-tensor/src/tests/ops/powf_scalar.rs +++ b/crates/burn-tensor/src/tests/ops/powf_scalar.rs @@ -6,18 +6,18 @@ mod tests { #[test] fn should_support_powf_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(0.71); let expected = TensorData::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq_diff(&expected, 0.04); } #[test] fn should_support_neg_power() { let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(-0.33); let expected = @@ -29,7 +29,7 @@ mod tests { #[test] fn should_support_neg_values_with_even_power() { let data = TensorData::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(4.0); let expected = TensorData::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]); @@ -40,11 +40,11 @@ mod tests { #[test] fn should_support_neg_values_with_odd_power() { let data = TensorData::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(3.0); let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq_diff(&expected, 0.5); } } diff --git a/crates/burn-tensor/src/tests/ops/random.rs b/crates/burn-tensor/src/tests/ops/random.rs index f00eddd54..da29536ea 100644 --- a/crates/burn-tensor/src/tests/ops/random.rs +++ b/crates/burn-tensor/src/tests/ops/random.rs @@ -1,12 +1,11 @@ #[burn_tensor_testgen::testgen(random)] mod tests { use super::*; - use burn_tensor::{Distribution, Tensor}; + use burn_tensor::{tests::Float, Distribution, Tensor}; #[test] fn rand_default() { - let tensor = - Tensor::::random([20], Distribution::Default, &Default::default()); + let tensor = TestTensor::<1>::random([20], Distribution::Default, &Default::default()); // check that the tensor is within the range of [0..1) (1 is exclusive) tensor.into_data().assert_within_range(0.0..1.0); @@ -14,23 +13,17 @@ mod tests { #[test] fn rand_uniform() { - let tensor = Tensor::::random( - [20], - Distribution::Uniform(4., 5.), - &Default::default(), - ); + let tensor = + 
TestTensor::<1>::random([20], Distribution::Uniform(4., 5.), &Default::default()); tensor.into_data().assert_within_range(4.0..5.0); } #[test] fn rand_bernoulli() { - let tensor = Tensor::::random( - [20], - Distribution::Bernoulli(1.), - &Default::default(), - ); + let tensor = + TestTensor::<1>::random([20], Distribution::Bernoulli(1.), &Default::default()); - assert_eq!(tensor.into_data(), [1f32; 20].into()); + assert_eq!(tensor.into_data(), [FloatType::new(1f32); 20].into()); } } diff --git a/crates/burn-tensor/src/tests/ops/recip.rs b/crates/burn-tensor/src/tests/ops/recip.rs index 0dba873ce..920e548ed 100644 --- a/crates/burn-tensor/src/tests/ops/recip.rs +++ b/crates/burn-tensor/src/tests/ops/recip.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_recip_ops() { let data = TensorData::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.recip(); let expected = TensorData::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); diff --git a/crates/burn-tensor/src/tests/ops/remainder.rs b/crates/burn-tensor/src/tests/ops/remainder.rs index 3c0f68f5d..5cdc99c4a 100644 --- a/crates/burn-tensor/src/tests/ops/remainder.rs +++ b/crates/burn-tensor/src/tests/ops/remainder.rs @@ -8,7 +8,7 @@ mod tests { fn should_support_remainder_basic() { let data = TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(2.0); let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); @@ -21,7 +21,7 @@ mod tests { fn should_support_remainder_float() { let data = TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(-1.5); let expected = TensorData::from([-0.5, -1.0, 0.0, -0.5, -1.0]); @@ -33,7 +33,7 @@ mod tests { fn should_be_zero() { let data = TensorData::from([0.0, 0.0, 0.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(3.5); let expected = TensorData::from([0.0, 0.0, 0.0]); @@ -45,7 +45,7 @@ mod tests { fn should_have_no_remainder() { let data = TensorData::from([-4.0, 4.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(4.0); let expected = TensorData::from([-0.0, 0.0]); @@ -57,7 +57,7 @@ mod tests { fn should_be_negative() { let data = TensorData::from([-7.0, -3.0, 2.0, 6.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(-2.5); let expected = TensorData::from([-2.0, -0.50, -0.50, -1.5]); @@ -69,7 +69,7 @@ mod tests { fn should_support_fp_dividends() { let data = TensorData::from([-7.5, -2.5, 2.5, 7.5]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(3.0); let expected = TensorData::from([1.5, 0.5, 2.5, 1.5]); @@ -81,7 +81,7 @@ mod tests { fn should_support_large_divisor() { let data = 
TensorData::from([-1.0, 1.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(10.0); let expected = TensorData::from([9.0, 1.0]); @@ -93,7 +93,7 @@ mod tests { fn should_support_remainder_op() { let data = TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor % 2.0; let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); diff --git a/crates/burn-tensor/src/tests/ops/repeat.rs b/crates/burn-tensor/src/tests/ops/repeat.rs index 253f8507c..84d3ea438 100644 --- a/crates/burn-tensor/src/tests/ops/repeat.rs +++ b/crates/burn-tensor/src/tests/ops/repeat.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_repeat_ops_one_dimension() { let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.repeat(&[4, 1, 1]); let expected = TensorData::from([ @@ -22,7 +22,7 @@ mod tests { #[test] fn should_support_bool_repeat_ops_one_dimension() { let data = TensorData::from([[true, false, false]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorBool::<2>::from_data(data, &Default::default()); let output = tensor.repeat(&[4, 1, 1]); let expected = TensorData::from([ @@ -37,7 +37,7 @@ mod tests { #[test] fn should_support_int_repeat_ops_one_dimension() { let data = TensorData::from([[0i32, 1i32, 2i32]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor.repeat(&[4, 1, 1]); let expected = TensorData::from([ @@ -58,7 +58,7 @@ mod tests { [[9.0f32, 10.0f32], [11.0f32, 12.0f32]], [[13.0f32, 14.0f32], [15.0f32, 16.0f32]], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); let output = tensor.repeat(&[2, 3, 2]); let expected = TensorData::from([ @@ -139,7 +139,7 @@ mod tests { [[9i32, 10i32], [11i32, 12i32]], [[13i32, 14i32], [15i32, 16i32]], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<3>::from_data(data, &Default::default()); let output = tensor.repeat(&[2, 3, 2]); @@ -219,7 +219,7 @@ mod tests { [[false, true], [true, false]], [[true, true], [false, false]], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorBool::<3>::from_data(data, &Default::default()); let output = tensor.repeat(&[2, 3, 2]); let expected = TensorData::from([ diff --git a/crates/burn-tensor/src/tests/ops/repeat_dim.rs b/crates/burn-tensor/src/tests/ops/repeat_dim.rs index dd2008968..f1fefc621 100644 --- a/crates/burn-tensor/src/tests/ops/repeat_dim.rs +++ b/crates/burn-tensor/src/tests/ops/repeat_dim.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_repeat_ops() { let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([ @@ -22,7 +22,7 @@ mod tests { #[test] fn should_support_bool_repeat_ops() { let data = TensorData::from([[true, false, false]]); - let 
tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorBool::<2>::from_data(data, &Default::default()); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([ @@ -37,7 +37,7 @@ mod tests { #[test] fn should_support_int_repeat_ops() { let data = TensorData::from([[0, 1, 2]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]); @@ -53,7 +53,7 @@ mod tests { [[9.0f32, 10.0f32], [11.0f32, 12.0f32]], [[13.0f32, 14.0f32], [15.0f32, 16.0f32]], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); let output = tensor.repeat_dim(2, 2); let expected = TensorData::from([ @@ -86,7 +86,7 @@ mod tests { [[9i32, 10i32], [11i32, 12i32]], [[13i32, 14i32], [15i32, 16i32]], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<3>::from_data(data, &Default::default()); let output = tensor.repeat_dim(2, 3); let expected = TensorData::from([ @@ -117,7 +117,7 @@ mod tests { [[false, true], [true, false]], [[true, true], [false, false]], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorBool::<3>::from_data(data, &Default::default()); let output = tensor.repeat_dim(1, 2); let expected = TensorData::from([ diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index e78ed70ef..bd9ffbf86 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_reshape_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.clone().reshape([1, 3]); let expected = TensorData::from([[0.0, 1.0, 2.0]]); @@ -17,7 +17,7 @@ mod tests { #[test] fn should_support_reshape_int() { let data = TensorData::from([0, 1, 2]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<1>::from_data(data, &Default::default()); let output = tensor.clone().reshape([1, 3]); let expected = TensorData::from([[0, 1, 2]]); @@ -28,7 +28,7 @@ mod tests { #[test] fn should_support_reshape_bool() { let data = TensorData::from([false, true, false]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorBool::<1>::from_data(data, &Default::default()); let output = tensor.clone().reshape([1, 3]); let expected = TensorData::from([[false, true, false]]); @@ -39,7 +39,7 @@ mod tests { #[test] fn should_support_reshape_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.clone().reshape([6]); let expected = TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); @@ -55,7 +55,7 @@ mod tests { [6.0, 7.0, 8.0], [9.0, 10.0, 11.0], ]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); // Infer the dimension via -1 let reshaped = tensor.clone().reshape([2, -1]); @@ -76,12 +76,12 @@ mod tests { #[test] fn should_not_corrupt_after_slice() { - let zeros = 
Tensor::::zeros([2], &Default::default()); + let zeros = TestTensor::<1>::zeros([2], &Default::default()); zeros.clone().slice([1..2]).reshape([1]).exp(); // May lead to zeroes being equal to [0.0, 1.0] zeros.into_data().assert_eq( - &Tensor::::zeros([2], &Default::default()).to_data(), + &TestTensor::<1>::zeros([2], &Default::default()).to_data(), true, ); } @@ -90,7 +90,7 @@ mod tests { #[should_panic] fn multiple_neg_ones() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); let data_actual = tensor.reshape([-1, -1]).into_data(); } @@ -98,7 +98,7 @@ mod tests { #[should_panic] fn neg_value() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); let data_actual = tensor.reshape([-2, -1]).into_data(); } } diff --git a/crates/burn-tensor/src/tests/ops/round.rs b/crates/burn-tensor/src/tests/ops/round.rs index 01e108951..fb6bd2030 100644 --- a/crates/burn-tensor/src/tests/ops/round.rs +++ b/crates/burn-tensor/src/tests/ops/round.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_round_ops() { let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.round(); let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]); @@ -14,7 +14,7 @@ mod tests { output.into_data().assert_approx_eq(&expected, 3); let data = TensorData::from([1.5, 2.5, 3.5, 4.5, 5.5, 6.5]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.round(); let expected = TensorData::from([2., 2., 4., 4., 6., 6.]); diff --git a/crates/burn-tensor/src/tests/ops/sin.rs b/crates/burn-tensor/src/tests/ops/sin.rs index e025162ad..8cb54da60 100644 --- a/crates/burn-tensor/src/tests/ops/sin.rs +++ b/crates/burn-tensor/src/tests/ops/sin.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_sin_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.sin(); let expected = TensorData::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); diff --git a/crates/burn-tensor/src/tests/ops/slice.rs b/crates/burn-tensor/src/tests/ops/slice.rs index 63c55a93b..61725a506 100644 --- a/crates/burn-tensor/src/tests/ops/slice.rs +++ b/crates/burn-tensor/src/tests/ops/slice.rs @@ -1,12 +1,12 @@ #[burn_tensor_testgen::testgen(slice)] mod tests { use super::*; - use burn_tensor::{Int, Tensor, TensorData}; + use burn_tensor::{as_type, Int, Tensor, TensorData}; #[test] fn should_support_full_sliceing_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([0..3]); @@ -16,7 +16,7 @@ mod tests { #[test] fn should_support_partial_sliceing_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.slice([1..3]); let expected = 
TensorData::from([1.0, 2.0]); @@ -27,7 +27,7 @@ mod tests { #[test] fn should_support_full_sliceing_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); let output = tensor.clone().slice([0..2]); output.into_data().assert_eq(&data, false); @@ -39,7 +39,7 @@ mod tests { #[test] fn should_support_partial_sliceing_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.slice([0..2, 0..2]); let expected = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); @@ -85,8 +85,8 @@ mod tests { let data_assigned = TensorData::from([10.0, 5.0]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); - let tensor_assigned = Tensor::::from_data(data_assigned, &device); + let tensor = TestTensor::<1>::from_data(data, &device); + let tensor_assigned = TestTensor::<1>::from_data(data_assigned, &device); let output = tensor.slice_assign([0..2], tensor_assigned); let expected = TensorData::from([10.0, 5.0, 2.0]); @@ -100,8 +100,8 @@ mod tests { let data_assigned = TensorData::from([[10.0, 5.0]]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); - let tensor_assigned = Tensor::::from_data(data_assigned, &device); + let tensor = TestTensor::<2>::from_data(data, &device); + let tensor_assigned = TestTensor::<2>::from_data(data_assigned, &device); let output = tensor.slice_assign([1..2, 0..2], tensor_assigned); let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); @@ -111,7 +111,7 @@ mod tests { #[test] fn slice_should_not_corrupt_potentially_inplace_operations() { - let tensor = Tensor::::from_data([1, 2, 3, 4, 5], &Default::default()); + let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &Default::default()); let tensor = tensor.clone().slice([0..3]) + tensor.clone().slice([2..5]); let expected = TensorData::from([4, 6, 8]); @@ -122,8 +122,8 @@ mod tests { #[test] fn slice_assign_should_not_corrupt_potentially_inplace_operations() { let device = Default::default(); - let tensor = Tensor::::from_data([1, 2, 3, 4, 5], &device); - let values = Tensor::::from_data([10, 20, 30], &device); + let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &device); + let values = TestTensorInt::<1>::from_data([10, 20, 30], &device); let tensor_1 = tensor.clone().slice_assign([0..3], values); let tensor_2 = tensor + 2; @@ -138,8 +138,8 @@ mod tests { #[test] fn clamp_when_slice_exceeds_dimension() { - let data = TensorData::from([0.0f32, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let data = TensorData::from(as_type!(FloatType: [0.0f32, 1.0, 2.0])); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([0..4]); output.into_data().assert_eq(&data, true); @@ -147,8 +147,8 @@ mod tests { #[test] fn negative_dimensions() { - let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let data = TensorData::from(as_type!(FloatType: [[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]])); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); // Clamping to the tensor dimensions let output = tensor.clone().slice([(0, 4), (0, 4)]); @@ -156,7 
+156,7 @@ mod tests { // Negative dimensions let output = tensor.clone().slice([(0, 1), (0, 1)]); - let data = TensorData::from([[0.0f32]]); + let data = TensorData::from(as_type!(FloatType: [[0.0f32]])); output.into_data().assert_eq(&data, true); let output = tensor.slice([(0, -1), (0, -2)]); @@ -165,29 +165,29 @@ mod tests { #[test] fn missing_dimensions() { - let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let data = TensorData::from(as_type!(FloatType: [[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]])); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); // Clamping to the tensor dimensions let output = tensor.clone().slice([Some((0, 4)), Some((0, 4))]); output.into_data().assert_eq(&data, true); // Negative dimensions - let data = TensorData::from([[0.0f32]]); + let data = TensorData::from(as_type!(FloatType: [[0.0f32]])); let output = tensor.clone().slice([Some((0, -1)), Some((0, -2))]); output.into_data().assert_eq(&data, true); // Missing dimensions let output = tensor.clone().slice([Some((0, 1)), None]); - let data = TensorData::from([[0.0f32, 1.0, 2.0]]); + let data = TensorData::from(as_type!(FloatType: [[0.0f32, 1.0, 2.0]])); output.into_data().assert_eq(&data, true); let output = tensor.clone().slice([None, Some((0, 2))]); - let data = TensorData::from([[0.0f32, 1.0], [3.0, 4.0]]); + let data = TensorData::from(as_type!(FloatType: [[0.0f32, 1.0], [3.0, 4.0]])); output.into_data().assert_eq(&data, true); let output = tensor.clone().slice([None, None]); - let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data = TensorData::from(as_type!(FloatType: [[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]])); output.into_data().assert_eq(&data, true); } @@ -204,7 +204,7 @@ mod tests { #[should_panic] fn should_panic_when_slice_with_too_many_dimensions() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([0..1, 0..1]); @@ -215,7 +215,7 @@ mod tests { #[should_panic] fn should_panic_when_slice_is_desc() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); #[allow(clippy::reversed_empty_ranges)] let output = tensor.slice([2..1]); @@ -227,7 +227,7 @@ mod tests { #[should_panic] fn should_panic_when_slice_is_equal() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([1..1]); diff --git a/crates/burn-tensor/src/tests/ops/sort_argsort.rs b/crates/burn-tensor/src/tests/ops/sort_argsort.rs index 7e2c7ae33..9b2a9236d 100644 --- a/crates/burn-tensor/src/tests/ops/sort_argsort.rs +++ b/crates/burn-tensor/src/tests/ops/sort_argsort.rs @@ -5,6 +5,11 @@ mod tests { #[test] fn test_sort_1d_int() { + // Skip with u8 + if (IntType::MAX as u32) < 1000u32 { + return; + } + let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, 2, 8, -10, 42, 1000]); // Sort along dim=0 @@ -16,6 +21,11 @@ mod tests { #[test] fn test_argsort_1d_int() { + // Skip with u8 + if (IntType::MAX as u32) < 1000u32 { + return; + } + let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]); // Sort 
along dim=0 @@ -27,17 +37,20 @@ mod tests { #[test] fn test_sort_with_indices_descending_int() { - // 1D - let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]); + // Skip with u8 + if (IntType::MAX as u32) >= 1000u32 { + // 1D + let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]); - // Sort along dim=0 - let (values, indices) = tensor.sort_descending_with_indices(0); + // Sort along dim=0 + let (values, indices) = tensor.sort_descending_with_indices(0); - let values_expected = TensorData::from([1000, 42, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -10]); - values.into_data().assert_eq(&values_expected, false); + let values_expected = TensorData::from([1000, 42, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -10]); + values.into_data().assert_eq(&values_expected, false); - let indices_expected = TensorData::from([12, 11, 8, 9, 2, 5, 4, 1, 6, 3, 0, 7, 10]); - indices.into_data().assert_eq(&indices_expected, false); + let indices_expected = TensorData::from([12, 11, 8, 9, 2, 5, 4, 1, 6, 3, 0, 7, 10]); + indices.into_data().assert_eq(&indices_expected, false); + } // 2D let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); @@ -142,7 +155,9 @@ mod tests { let values_expected = TensorData::from([ -8.1, -0.3, -0.21, 0., 0.5, 0.94, 0.99, 1.2, 2.1, 2.3, 3., 4., 199.412, ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.04); } #[test] @@ -171,7 +186,9 @@ mod tests { let values_expected = TensorData::from([ 199.412, 4., 3., 2.3, 2.1, 1.2, 0.99, 0.94, 0.5, 0., -0.21, -0.3, -8.1, ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.04); let indices_expected = TensorData::from([8, 9, 11, 7, 4, 1, 10, 5, 0, 3, 2, 6, 12]); indices.into_data().assert_eq(&indices_expected, false); @@ -189,7 +206,9 @@ mod tests { [[0., 2.1, 0.94], [-0.5, 1.2, -0.21]], [[0.99, 3., 4.], [-0.3, 2.3, -8.1]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.04); let indices_expected = TensorData::from([[[1, 1, 1], [0, 0, 0]], [[1, 1, 0], [0, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); @@ -209,7 +228,9 @@ mod tests { [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.002); // Sort along dim=1 let values = tensor.clone().sort(1); @@ -218,7 +239,9 @@ mod tests { [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.002); // Sort along dim=2 let values = tensor.sort(2); @@ -227,7 +250,9 @@ mod tests { [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.002); } #[test] @@ -243,7 +268,9 @@ mod tests { [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.002); let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); 
indices.into_data().assert_eq(&indices_expected, false); @@ -255,7 +282,9 @@ mod tests { [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.002); let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); @@ -267,7 +296,9 @@ mod tests { [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], ]); - values.into_data().assert_approx_eq(&values_expected, 5); + values + .into_data() + .assert_approx_eq_diff(&values_expected, 0.002); let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); @@ -307,7 +338,7 @@ mod tests { let values = tensor.sort(0); let values_expected = TensorData::from([[-0.5, 0.94], [-0.3, f32::NAN], [0., f32::NAN]]); - values.into_data().assert_approx_eq(&values_expected, 5); + values.into_data().assert_approx_eq(&values_expected, 4); } #[test] diff --git a/crates/burn-tensor/src/tests/ops/sqrt.rs b/crates/burn-tensor/src/tests/ops/sqrt.rs index df88803a1..e30e2e6cb 100644 --- a/crates/burn-tensor/src/tests/ops/sqrt.rs +++ b/crates/burn-tensor/src/tests/ops/sqrt.rs @@ -7,11 +7,11 @@ mod tests { #[test] fn should_support_sqrt_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.sqrt(); let expected = TensorData::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq_diff(&expected, 0.002); } } diff --git a/crates/burn-tensor/src/tests/ops/squeeze.rs b/crates/burn-tensor/src/tests/ops/squeeze.rs index dde6fdd48..a04205165 100644 --- a/crates/burn-tensor/src/tests/ops/squeeze.rs +++ b/crates/burn-tensor/src/tests/ops/squeeze.rs @@ -6,7 +6,7 @@ mod tests { /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. #[test] fn should_squeeze() { - let tensor = Tensor::::ones(Shape::new([2, 1, 4]), &Default::default()); + let tensor = TestTensor::<3>::ones(Shape::new([2, 1, 4]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze(1); let expected_shape = Shape::new([2, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -14,7 +14,7 @@ mod tests { /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. #[test] fn should_squeeze_first() { - let tensor = Tensor::::ones(Shape::new([1, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 4, 5]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze(0); let expected_shape = Shape::new([3, 4, 5]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -22,7 +22,7 @@ mod tests { /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. 
#[test] fn should_squeeze_last() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 1]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 1]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze(3); let expected_shape = Shape::new([2, 3, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -31,14 +31,14 @@ mod tests { #[test] #[should_panic] fn should_squeeze_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze(2); } /// Test if the function works with an empty slice #[test] fn should_squeeze_dims_with_empty_slice() { - let tensor = Tensor::::ones(Shape::new([1, 1, 3]), &Default::default()); + let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 3]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze_dims(&[]); let expected_shape = Shape::new([3]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -47,7 +47,7 @@ mod tests { /// Test if the function works with positive indices #[test] fn should_squeeze_dims_with_positive_indices() { - let tensor = Tensor::::ones(Shape::new([1, 3, 1, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze_dims(&[0, 2]); let expected_shape = Shape::new([3, 5]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -56,7 +56,7 @@ mod tests { /// Test if the function works with negative indices #[test] fn should_squeeze_dims_with_negative_indices() { - let tensor = Tensor::::ones(Shape::new([2, 1, 3, 1]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 1, 3, 1]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze_dims(&[-3, -1]); let expected_shape = Shape::new([2, 3]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -66,7 +66,7 @@ mod tests { #[test] #[should_panic] fn should_squeeze_dims_work_if_non_singleton() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4]), &Default::default()); + let tensor = TestTensor::<3>::ones(Shape::new([2, 3, 4]), &Default::default()); let squeezed_tensor: Tensor = tensor.squeeze_dims(&[1]); let expected_shape = Shape::new([2, 3, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); @@ -76,7 +76,7 @@ mod tests { #[test] #[should_panic] fn should_squeeze_dims_panic_on_too_many_dimensions() { - let tensor = Tensor::::ones(Shape::new([1, 1, 1]), &Default::default()); + let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default()); let _: Tensor = tensor.squeeze_dims(&[0, 1, 2]); } @@ -84,14 +84,14 @@ mod tests { #[test] #[should_panic] fn should_squeeze_dims_dimension_mismatch_panic() { - let tensor = Tensor::::ones(Shape::new([1, 3, 1, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default()); let _: Tensor = tensor.squeeze_dims(&[0, 2]); } /// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor. 
#[test] fn should_unsqueeze_dim() { - let tensor = Tensor::::ones(Shape::new([2, 4, 1]), &Default::default()); + let tensor = TestTensor::<3>::ones(Shape::new([2, 4, 1]), &Default::default()); let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(1); let expected_shape = Shape::new([2, 1, 4, 1]); assert_eq!(unsqueezed_tensor.shape(), expected_shape); @@ -100,7 +100,7 @@ mod tests { /// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor. #[test] fn should_unsqueeze_dim_first() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(0); let expected_shape = Shape::new([1, 2, 3, 4, 5]); assert_eq!(unsqueezed_tensor.shape(), expected_shape); @@ -109,7 +109,7 @@ mod tests { /// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor. #[test] fn should_unsqueeze_dim_last() { - let tensor = Tensor::::ones(Shape::new([5, 4, 3, 2]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([5, 4, 3, 2]), &Default::default()); let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(4); let expected_shape = Shape::new([5, 4, 3, 2, 1]); assert_eq!(unsqueezed_tensor.shape(), expected_shape); @@ -119,14 +119,13 @@ mod tests { #[test] #[should_panic] fn should_unsqueeze_dim_panic() { - let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5]), &Default::default()); + let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(5); } #[test] fn should_unsqueeze_dims_support_dim_inference() { - let input_tensor = - Tensor::::ones(Shape::new([3, 4, 5]), &Default::default()); + let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor = input_tensor.unsqueeze_dims::<5>(&[1, -2]); let expected_shape = Shape::new([3, 1, 4, 1, 5]); assert_eq!(output_tensor.shape(), expected_shape); @@ -134,8 +133,7 @@ mod tests { #[test] fn should_unsqueeze_dims_handle_first_last() { - let input_tensor = - Tensor::::ones(Shape::new([3, 4, 5]), &Default::default()); + let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor = input_tensor.unsqueeze_dims::<5>(&[0, 4]); let expected_shape = Shape::new([1, 3, 4, 5, 1]); assert_eq!(output_tensor.shape(), expected_shape); @@ -144,8 +142,7 @@ mod tests { #[test] fn should_unsqueeze_dims_work_with_single_dim() { //bruh, just call unsqueeze_dim - let input_tensor = - Tensor::::ones(Shape::new([3, 4, 5]), &Default::default()); + let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor: Tensor = input_tensor.unsqueeze_dims(&[1]); let expected_shape = Shape::new([3, 1, 4, 5]); assert_eq!(output_tensor.shape(), expected_shape); @@ -154,8 +151,7 @@ mod tests { #[test] #[should_panic] fn should_unsqueeze_dims_panic() { - let input_tensor = - Tensor::::ones(Shape::new([3, 4, 5]), &Default::default()); + let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor: Tensor = input_tensor.unsqueeze_dims(&[0, -6]); } } diff --git a/crates/burn-tensor/src/tests/ops/stack.rs b/crates/burn-tensor/src/tests/ops/stack.rs index c8a6582b7..970132f40 100644 --- a/crates/burn-tensor/src/tests/ops/stack.rs +++ b/crates/burn-tensor/src/tests/ops/stack.rs @@ -19,8 +19,8 @@ mod tests { #[test] fn 
should_support_stack_ops_int() { let device = Default::default(); - let tensor_1 = Tensor::::from_data([[1, 2, 3]], &device); - let tensor_2 = Tensor::::from_data([[4, 5, 6]], &device); + let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device); + let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1, 2, 3]], [[4, 5, 6]]]); @@ -31,8 +31,8 @@ mod tests { #[test] fn should_support_stack_ops_bool() { let device = Default::default(); - let tensor_1 = Tensor::::from_data([[false, true, true]], &device); - let tensor_2 = Tensor::::from_data([[true, true, false]], &device); + let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device); + let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[false, true, true]], [[true, true, false]]]); @@ -101,7 +101,7 @@ mod tests { #[test] fn should_generate_row_major_layout() { let device = Default::default(); - let tensor = Tensor::::arange(1..25, &device).reshape([4, 6]); + let tensor = TestTensorInt::<1>::arange(1..25, &device).reshape([4, 6]); let zeros: Tensor = Tensor::zeros([4, 6], &device); let intersperse = Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]); diff --git a/crates/burn-tensor/src/tests/ops/sub.rs b/crates/burn-tensor/src/tests/ops/sub.rs index 823f769a5..4765cb637 100644 --- a/crates/burn-tensor/src/tests/ops/sub.rs +++ b/crates/burn-tensor/src/tests/ops/sub.rs @@ -8,8 +8,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_2 = TensorData::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); @@ -22,8 +22,8 @@ mod tests { let data_1 = TensorData::from([[0.0, 1.0, 2.0]]); let data_2 = TensorData::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensor::<2>::from_data(data_1, &device); + let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); @@ -35,7 +35,7 @@ mod tests { fn should_support_sub_scalar_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor - scalar; let expected = TensorData::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); @@ -48,8 +48,8 @@ mod tests { let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let data_2 = TensorData::from([[6, 7, 8], [9, 10, 11]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = 
TensorData::from([[-6, -6, -6], [-6, -6, -6]]); @@ -62,8 +62,8 @@ mod tests { let data_1 = TensorData::from([[0, 1, 2]]); let data_2 = TensorData::from([[3, 4, 5], [6, 7, 8]]); let device = Default::default(); - let tensor_1 = Tensor::::from_data(data_1, &device); - let tensor_2 = Tensor::::from_data(data_2, &device); + let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); + let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-3, -3, -3], [-6, -6, -6]]); @@ -75,7 +75,7 @@ mod tests { fn should_support_sub_scalar_ops_int() { let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let scalar = 2; - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor - scalar; let expected = TensorData::from([[-2, -1, 0], [1, 2, 3]]); diff --git a/crates/burn-tensor/src/tests/ops/tanh.rs b/crates/burn-tensor/src/tests/ops/tanh.rs index 49e2d8128..84b7e558d 100644 --- a/crates/burn-tensor/src/tests/ops/tanh.rs +++ b/crates/burn-tensor/src/tests/ops/tanh.rs @@ -6,7 +6,7 @@ mod tests { #[test] fn should_support_tanh_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.tanh(); let expected = TensorData::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); diff --git a/crates/burn-tensor/src/tests/ops/transpose.rs b/crates/burn-tensor/src/tests/ops/transpose.rs index 84ae67dde..a88c68451 100644 --- a/crates/burn-tensor/src/tests/ops/transpose.rs +++ b/crates/burn-tensor/src/tests/ops/transpose.rs @@ -44,7 +44,7 @@ mod tests { #[test] fn should_support_transpose_ops_int() { - let tensor = Tensor::::from_data( + let tensor = TestTensorInt::<3>::from_data( [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], &Default::default(), ); @@ -57,7 +57,7 @@ mod tests { #[test] fn should_support_swap_dims_int() { - let tensor = Tensor::::from_data( + let tensor = TestTensorInt::<3>::from_data( [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], &Default::default(), ); @@ -70,7 +70,7 @@ mod tests { #[test] fn should_support_transpose_bool() { - let tensor = Tensor::::from_data( + let tensor = TestTensorBool::<3>::from_data( [ [[false, true, false], [false, false, false]], [[false, false, true], [false, false, true]], @@ -89,7 +89,7 @@ mod tests { #[test] fn should_support_swap_dims_bool() { - let tensor = Tensor::::from_data( + let tensor = TestTensorBool::<3>::from_data( [ [[false, true, false], [false, false, false]], [[false, false, true], [false, false, true]], diff --git a/crates/burn-tensor/src/tests/ops/tri_mask.rs b/crates/burn-tensor/src/tests/ops/tri_mask.rs index 869ac9aba..26a04a7a6 100644 --- a/crates/burn-tensor/src/tests/ops/tri_mask.rs +++ b/crates/burn-tensor/src/tests/ops/tri_mask.rs @@ -11,7 +11,7 @@ mod tests { [true, false, true], [true, true, false], ]); - let tensor = Tensor::::diag_mask([3, 3], 0, &device); + let tensor = TestTensorBool::<2>::diag_mask([3, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, true); } @@ -20,7 +20,7 @@ mod tests { let device = Default::default(); let data_expected = TensorData::from([[true, false, true], [true, true, false], [true, true, true]]); - let tensor = Tensor::::diag_mask([3, 3], 1, &device); + let tensor = TestTensorBool::<2>::diag_mask([3, 3], 1, &device); 
tensor.into_data().assert_eq(&data_expected, true); } @@ -32,7 +32,7 @@ mod tests { [true, false, false], [true, true, false], ]); - let tensor = Tensor::::triu_mask([3, 3], 0, &device); + let tensor = TestTensorBool::<2>::triu_mask([3, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, true); } @@ -44,7 +44,7 @@ mod tests { [true, true, false], [true, true, true], ]); - let tensor = Tensor::::triu_mask([3, 3], 1, &device); + let tensor = TestTensorBool::<2>::triu_mask([3, 3], 1, &device); tensor.into_data().assert_eq(&data_expected, true); } @@ -57,7 +57,7 @@ mod tests { [false, false, true], [false, false, false], ]); - let tensor = Tensor::::tril_mask([3, 3], 0, &device); + let tensor = TestTensorBool::<2>::tril_mask([3, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, true); } @@ -70,7 +70,7 @@ mod tests { [false, true, true], [false, false, true], ]); - let tensor = Tensor::::tril_mask([3, 3], -1, &device); + let tensor = TestTensorBool::<2>::tril_mask([3, 3], -1, &device); tensor.into_data().assert_eq(&data_expected, true); } @@ -82,7 +82,7 @@ mod tests { [true, false, true, true], [true, true, false, true], ]); - let tensor = Tensor::::diag_mask([3, 4], 0, &device); + let tensor = TestTensorBool::<2>::diag_mask([3, 4], 0, &device); tensor.into_data().assert_eq(&data_expected, true); let data_expected = TensorData::from([ @@ -91,7 +91,7 @@ mod tests { [true, true, false], [true, true, true], ]); - let tensor = Tensor::::diag_mask([4, 3], 0, &device); + let tensor = TestTensorBool::<2>::diag_mask([4, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, true); } } diff --git a/crates/burn-tensor/src/tests/quantization/calibration.rs b/crates/burn-tensor/src/tests/quantization/calibration.rs index 2dbee1d36..8140be4f2 100644 --- a/crates/burn-tensor/src/tests/quantization/calibration.rs +++ b/crates/burn-tensor/src/tests/quantization/calibration.rs @@ -8,8 +8,7 @@ mod tests { #[test] fn min_max_calibration_range() { - let tensor = - Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default()); + let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default()); let calibration = MinMaxCalibration {}; let range = calibration.compute_range(&tensor); diff --git a/crates/burn-tensor/src/tests/quantization/ops/mask.rs b/crates/burn-tensor/src/tests/quantization/ops/mask.rs index 1ec47dc30..49bcf57a8 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/mask.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/mask.rs @@ -14,7 +14,7 @@ mod tests { QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), ); let tensor = TestTensor::<2>::from_data(data, &device); - let mask = Tensor::::from_bool( + let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); @@ -46,7 +46,7 @@ mod tests { QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), ); let tensor = TestTensor::<2>::from_data(data, &device); - let mask = Tensor::::from_bool( + let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &device, ); diff --git a/crates/burn-tensor/src/tests/quantization/ops/powf.rs b/crates/burn-tensor/src/tests/quantization/ops/powf.rs index 79252ec6f..b948a9b39 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/powf.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/powf.rs @@ -71,7 +71,7 @@ mod tests { [2, 3], QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 
126)), ); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<2>::from_data(data, &device); // Quantized [[4.0, 2.0, 4.0], [2.0, 4.0, 2.0]] (with range [2., 5.] to reduce quantization errors) let data = TensorData::quantized( vec![76i8, -26, 76, -26, 76, -26], @@ -99,7 +99,7 @@ mod tests { [2, 3], QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)), ); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); // Quantized [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]] let data = TensorData::quantized( vec![127i8, 127, 127, 127, 127, 127], diff --git a/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs b/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs index fe2b9ed67..a7cc9c108 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs @@ -55,7 +55,7 @@ mod tests { [2, 3], QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)), ); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(2.0); let expected = TensorData::from([[0., 1., 4.], [9., 16., 25.]]); @@ -75,7 +75,7 @@ mod tests { [2, 3], QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)), ); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(3.0); let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]); diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 438b973f0..72834b8c8 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -11,7 +11,7 @@ mod tests { #[test] fn should_support_quantize_affine_int8() { let device = Default::default(); - let tensor = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5], &device); + let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device); let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); let qparams = QuantizationParameters { scale: Tensor::from_floats([0.009_019_608], &device), @@ -32,7 +32,7 @@ mod tests { #[test] fn should_support_quantize_symmetric_int8() { let device = Default::default(); - let tensor = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5], &device); + let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device); let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); let qparams = QuantizationParameters { scale: Tensor::from_floats([0.014_173_228], &device), @@ -63,7 +63,7 @@ mod tests { 0.014_173_228, )), ); - let x_q = Tensor::::from_data(data, &device); + let x_q = TestTensor::<1>::from_data(data, &device); let x = x_q.dequantize(); @@ -77,7 +77,7 @@ mod tests { let device = Default::default(); // NOTE: we use fully representable values since different backend implementations could differ slightly // due to rounding discrepancies - let tensor = Tensor::::from_floats([5., 0., 4., -10.], &device); + let tensor = TestTensor::<1>::from_floats([5., 0., 4., -10.], &device); let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); let x_q = tensor.quantize_dynamic(&scheme); diff --git 
a/crates/burn-tensor/src/tests/quantization/ops/reshape.rs b/crates/burn-tensor/src/tests/quantization/ops/reshape.rs index 6d42ee56a..e34c68ff1 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/reshape.rs @@ -78,7 +78,7 @@ mod tests { // May lead to zeroes being equal to [0.0, 1.0] zeros.dequantize().into_data().assert_eq( - &Tensor::::zeros([2], &Default::default()).to_data(), + &TestTensor::<1>::zeros([2], &Default::default()).to_data(), true, ); } diff --git a/crates/burn-tensor/src/tests/quantization/ops/slice.rs b/crates/burn-tensor/src/tests/quantization/ops/slice.rs index 5beedb676..77186b3c4 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/slice.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/slice.rs @@ -120,7 +120,7 @@ mod tests { [2], QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.039215688, -128)), ); - let tensor_assigned = Tensor::::from_data(data, &device); + let tensor_assigned = TestTensor::<1>::from_data(data, &device); let output = tensor.slice_assign([0..2], tensor_assigned); let expected = TensorData::from([10.0, 5.0, 2.0]); @@ -148,7 +148,7 @@ mod tests { [1, 2], QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.039215688, -128)), ); - let tensor_assigned = Tensor::::from_data(data, &device); + let tensor_assigned = TestTensor::<2>::from_data(data, &device); let output = tensor.slice_assign([1..2, 0..2], tensor_assigned); let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); @@ -291,7 +291,7 @@ mod tests { #[should_panic] fn should_panic_when_slice_with_too_many_dimensions() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([0..1, 0..1]); @@ -302,7 +302,7 @@ mod tests { #[should_panic] fn should_panic_when_slice_is_desc() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); #[allow(clippy::reversed_empty_ranges)] let output = tensor.slice([2..1]); @@ -314,7 +314,7 @@ mod tests { #[should_panic] fn should_panic_when_slice_is_equal() { let data = TensorData::from([0.0, 1.0, 2.0]); - let tensor = Tensor::::from_data(data.clone(), &Default::default()); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([1..1]); diff --git a/crates/burn-tensor/src/tests/quantization/scheme.rs b/crates/burn-tensor/src/tests/quantization/scheme.rs index e8e23050c..a305d25f4 100644 --- a/crates/burn-tensor/src/tests/quantization/scheme.rs +++ b/crates/burn-tensor/src/tests/quantization/scheme.rs @@ -11,8 +11,8 @@ mod tests { let device = Default::default(); let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); let range = CalibrationRange { - min: Tensor::::from_floats([-1.8], &device), - max: Tensor::::from_floats([0.5], &device), + min: TestTensor::<1>::from_floats([-1.8], &device), + max: TestTensor::<1>::from_floats([0.5], &device), }; let qparams = scheme.compute_q_params(range); @@ -33,8 +33,8 @@ mod tests { let device = Default::default(); let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); let range = CalibrationRange { - min: Tensor::::from_floats([-1.8], &device), - max: Tensor::::from_floats([0.5], &device), + min: 
TestTensor::<1>::from_floats([-1.8], &device), + max: TestTensor::<1>::from_floats([0.5], &device), }; let qparams = scheme.compute_q_params(range); diff --git a/crates/burn-tensor/src/tests/stats/cov.rs b/crates/burn-tensor/src/tests/stats/cov.rs index 8eadb1b24..35c8352dd 100644 --- a/crates/burn-tensor/src/tests/stats/cov.rs +++ b/crates/burn-tensor/src/tests/stats/cov.rs @@ -10,7 +10,7 @@ mod tests { #[test] fn test_cov_1() { let data = TensorData::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cov(1, 1); let expected = @@ -22,7 +22,7 @@ mod tests { #[test] fn test_cov_4() { let data = TensorData::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cov(1, 0); let expected = @@ -34,7 +34,7 @@ mod tests { #[test] fn test_cov_2() { let data = TensorData::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]); - let tensor = Tensor::::from_data(data, &Default::default()); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cov(1, 1); let expected = TensorData::from([ @@ -45,7 +45,7 @@ mod tests { ]) .convert::(); - output.into_data().assert_approx_eq(&expected, 3); + output.into_data().assert_approx_eq(&expected, 2); } #[test] @@ -57,9 +57,9 @@ mod tests { [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], ]); let device = Default::default(); - let tensor = Tensor::::from_data(data, &device); + let tensor = TestTensor::<3>::from_data(data, &device); let data_actual = tensor.cov(0, 1).into_data(); - let data_expected = Tensor::::zeros([4, 4, 4], &device).to_data(); + let data_expected = TestTensor::<3>::zeros([4, 4, 4], &device).to_data(); data_expected.assert_approx_eq(&data_actual, 3); } } diff --git a/crates/burn-tensor/src/tests/stats/display.rs b/crates/burn-tensor/src/tests/stats/display.rs index e2b2e7c73..72a7ca9c0 100644 --- a/crates/burn-tensor/src/tests/stats/display.rs +++ b/crates/burn-tensor/src/tests/stats/display.rs @@ -283,7 +283,7 @@ mod tests { } #[test] fn test_display_precision() { - let tensor = Tensor::::full([1, 1], 0.123456789, &Default::default()); + let tensor = TestTensor::<2>::full([1, 1], 0.123456789, &Default::default()); let output = format!("{}", tensor); let expected = format!( @@ -308,7 +308,7 @@ mod tests { // }; // set_print_options(print_options); - let tensor = Tensor::::full([3, 2], 0.123456789, &Default::default()); + let tensor = TestTensor::<2>::full([3, 2], 0.123456789, &Default::default()); // Set precision to 3 let output = format!("{:.3}", tensor); diff --git a/crates/burn-tensor/src/tests/stats/eye.rs b/crates/burn-tensor/src/tests/stats/eye.rs index 738d22d04..b3bc9c934 100644 --- a/crates/burn-tensor/src/tests/stats/eye.rs +++ b/crates/burn-tensor/src/tests/stats/eye.rs @@ -8,14 +8,14 @@ mod tests { fn test_eye_float() { let device = Default::default(); let tensor = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); - let rhs = Tensor::::eye(3, &device); + let rhs = TestTensor::<2>::eye(3, &device); assert_eq!(tensor.to_data(), rhs.to_data()); } fn test_eye_int() { let device = Default::default(); let tensor = TestTensorInt::<2>::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]); - let rhs = Tensor::::eye(3, &device); + let rhs = TestTensorInt::<2>::eye(3, &device); 
assert_eq!(tensor.to_data(), rhs.to_data()); } } diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 1e94d1f97..ea629af3f 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -35,6 +35,8 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ "export_tests", ] } +half = { workspace = true } +paste = { workspace = true } [package.metadata.docs.rs] features = ["default"] diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 28e54379c..7c26dcc31 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -11,6 +11,7 @@ pub use burn_jit::{ pub use burn_jit::{tensor::JitTensor, JitBackend}; pub use burn_jit::{FloatElement, IntElement}; +pub use cubecl::flex32; pub use cubecl::ir::CubeDim; pub use cubecl::wgpu::*; @@ -95,7 +96,14 @@ pub type Wgpu = JitBackend; - burn_jit::testgen_all!(); + // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it + // breaks a lot of tests from precision issues + #[cfg(feature = "spirv")] + burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64]); + #[cfg(not(feature = "spirv"))] + burn_jit::testgen_all!([f32], [i32]); }
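
Note on the element-type guards added above: because the test suites are now generated for narrow element types (i8, i16, f16, and so on), several hunks skip fixtures whose literals do not fit the element type, e.g. `if (IntType::MAX as u32) < 1000u32 { return; }` in the integer sort tests, and relax float comparisons via `assert_approx_eq_diff`. The following standalone Rust sketch illustrates only that skip-by-range pattern; the names `sort_values` and `check_sort_for` are hypothetical and are not part of this patch or of burn's API.

fn sort_values<I: Ord>(mut values: Vec<I>) -> Vec<I> {
    values.sort();
    values
}

// Hypothetical stand-in for a generated test: `max` plays the role of
// `IntType::MAX` in the patch, and the fixture contains 1000, which does
// not fit into an 8-bit element type.
fn check_sort_for<I>(max: i64)
where
    I: Ord + TryFrom<i64>,
{
    if max < 1000 {
        // Mirrors the `if (IntType::MAX as u32) < 1000u32 { return; }` guard above.
        return;
    }
    let input: Vec<I> = [1i64, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]
        .iter()
        .map(|&v| I::try_from(v).ok().expect("fixture value fits the element type"))
        .collect();
    let sorted = sort_values(input);
    // The sorted output must be non-decreasing.
    assert!(sorted.windows(2).all(|w| w[0] <= w[1]));
}

fn main() {
    check_sort_for::<i8>(i8::MAX as i64); // skipped: 1000 exceeds i8::MAX
    check_sort_for::<i32>(i32::MAX as i64); // runs and checks the ordering
}

The generated suites in the patch take the same approach, only with `TestTensorInt` tensors in place of plain vectors.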