Add more type support for burn-jit (#2454)

Genna Wingert 2024-11-04 19:01:01 +01:00 committed by GitHub
parent 5597657314
commit 42f39f16b3
151 changed files with 2367 additions and 1557 deletions
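In short, the diff below widens burn-jit's element support (u8/u16/u64, i8/i16, f64, flex32, plus bf16 handling in fusion), tracks fusion scalars in per-dtype buffers instead of a single i32 list, and lets each backend choose which element types its generated tests cover. A minimal sketch of the new test opt-in, mirroring the burn-cuda test diff further down (the surrounding module layout is an assumption, not part of this commit):

mod tests {
    use burn_jit::JitBackend;

    pub type TestRuntime = cubecl::cuda::CudaRuntime;
    pub use half::{bf16, f16};

    // Generates the full test suite once per (float, int) element combination.
    burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]);
}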

Cargo.lock (generated)

@ -313,7 +313,7 @@ dependencies = [
"clap 4.5.20",
"colored",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"dirs 5.0.1",
"github-device-flow",
"half",
@ -321,7 +321,7 @@ dependencies = [
"os_info",
"percent-encoding",
"rand",
"reqwest 0.12.8",
"reqwest 0.12.9",
"rstest",
"serde",
"serde_json",
@ -503,7 +503,7 @@ dependencies = [
"burn-common",
"burn-tensor",
"burn-tensor-testgen",
"derive-new",
"derive-new 0.7.0",
"log",
"spin",
]
@ -516,7 +516,7 @@ dependencies = [
"burn-tch",
"burn-tensor",
"candle-core",
"derive-new",
"derive-new 0.7.0",
"half",
]
@ -529,7 +529,7 @@ dependencies = [
"getrandom",
"indicatif",
"rayon",
"reqwest 0.12.8",
"reqwest 0.12.9",
"tokio",
"web-time",
]
@ -552,7 +552,7 @@ dependencies = [
"burn-tensor",
"burn-wgpu",
"data-encoding",
"derive-new",
"derive-new 0.7.0",
"flate2",
"half",
"hashbrown 0.15.0",
@ -579,9 +579,10 @@ dependencies = [
"burn-tensor",
"bytemuck",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"half",
"log",
"paste",
]
[[package]]
@ -590,7 +591,7 @@ version = "0.16.0"
dependencies = [
"burn-common",
"csv",
"derive-new",
"derive-new 0.7.0",
"dirs 5.0.1",
"fake",
"flate2",
@ -620,7 +621,7 @@ dependencies = [
name = "burn-derive"
version = "0.16.0"
dependencies = [
"derive-new",
"derive-new 0.7.0",
"proc-macro2",
"quote",
"syn 2.0.87",
@ -632,7 +633,7 @@ version = "0.16.0"
dependencies = [
"burn-common",
"burn-tensor",
"derive-new",
"derive-new 0.7.0",
"half",
"hashbrown 0.15.0",
"log",
@ -649,7 +650,7 @@ dependencies = [
"burn-tensor",
"bytemuck",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"half",
"log",
]
@ -660,7 +661,7 @@ version = "0.16.0"
dependencies = [
"burn",
"candle-core",
"derive-new",
"derive-new 0.7.0",
"half",
"log",
"onnx-ir",
@ -691,12 +692,13 @@ dependencies = [
"burn-tensor-testgen",
"bytemuck",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"futures-lite",
"half",
"hashbrown 0.15.0",
"log",
"num-traits",
"paste",
"rand",
"serde",
"serial_test",
@ -713,7 +715,7 @@ dependencies = [
"burn-autodiff",
"burn-common",
"burn-tensor",
"derive-new",
"derive-new 0.7.0",
"libm",
"matrixmultiply",
"ndarray 0.16.1",
@ -767,7 +769,7 @@ dependencies = [
"bytemuck",
"colored",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"half",
"hashbrown 0.15.0",
"num-traits",
@ -792,7 +794,7 @@ version = "0.16.0"
dependencies = [
"burn-core",
"burn-ndarray",
"derive-new",
"derive-new 0.7.0",
"log",
"nvml-wrapper",
"ratatui",
@ -812,6 +814,8 @@ dependencies = [
"burn-jit",
"burn-tensor",
"cubecl",
"half",
"paste",
]
[[package]]
@ -881,15 +885,15 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.6.0"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
checksum = "7e1a39b963e261c58017edf2007e5b63425ad21538aaaf51fe23d1da41703701"
dependencies = [
"accelerate-src",
"byteorder",
"candle-kernels",
"candle-metal-kernels",
"cudarc 0.11.9",
"cudarc",
"gemm",
"half",
"libc",
@ -908,18 +912,18 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.6.0"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
checksum = "539cbfbf2d1d68a6ed97115e579c77c98f8ed0cfe7edbc6d7d30d2ac0c9e3d50"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-metal-kernels"
version = "0.6.0"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f889aacd02fd999620a0435133d7cf3b58c81ef9dd5e47c38939b7a72345ea86"
checksum = "166a92826d615d98b205e346e52128fa0439f2ab3302587403fdc558b4219e19"
dependencies = [
"metal 0.27.0",
"once_cell",
@ -1460,7 +1464,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"cubecl-core",
"cubecl-cuda",
@ -1476,7 +1480,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4"
dependencies = [
"derive-new",
"derive-new 0.6.0",
"embassy-futures",
"futures-lite",
"getrandom",
@ -1491,9 +1495,9 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"derive-new",
"derive-new 0.6.0",
"embassy-futures",
"futures-lite",
"getrandom",
@ -1508,13 +1512,14 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
"cubecl-macros",
"cubecl-runtime 0.4.0",
"derive-new",
"derive-new 0.6.0",
"derive_more 1.0.0",
"half",
"log",
"num-traits",
@ -1525,13 +1530,13 @@ dependencies = [
[[package]]
name = "cubecl-cpp"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
"cubecl-core",
"cubecl-runtime 0.4.0",
"derive-new",
"derive-new 0.6.0",
"half",
"log",
]
@ -1539,15 +1544,15 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
"cubecl-core",
"cubecl-cpp",
"cubecl-runtime 0.4.0",
"cudarc 0.12.1",
"derive-new",
"cudarc",
"derive-new 0.6.0",
"half",
"log",
]
@ -1555,7 +1560,7 @@ dependencies = [
[[package]]
name = "cubecl-hip"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
@ -1563,7 +1568,7 @@ dependencies = [
"cubecl-cpp",
"cubecl-hip-sys",
"cubecl-runtime 0.4.0",
"derive-new",
"derive-new 0.6.0",
"half",
"log",
]
@ -1580,7 +1585,7 @@ dependencies = [
[[package]]
name = "cubecl-linalg"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"bytemuck",
"cubecl-core",
@ -1591,11 +1596,11 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"cubecl-common 0.4.0",
"darling",
"derive-new",
"derive-new 0.6.0",
"ident_case",
"prettyplease",
"proc-macro2",
@ -1606,7 +1611,7 @@ dependencies = [
[[package]]
name = "cubecl-opt"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"cubecl-common 0.4.0",
"cubecl-core",
@ -1628,7 +1633,7 @@ dependencies = [
"async-lock",
"cfg_aliases 0.2.1",
"cubecl-common 0.3.0",
"derive-new",
"derive-new 0.6.0",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
@ -1643,13 +1648,13 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"async-channel",
"async-lock",
"cfg_aliases 0.2.1",
"cubecl-common 0.4.0",
"derive-new",
"derive-new 0.6.0",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
@ -1664,12 +1669,13 @@ dependencies = [
[[package]]
name = "cubecl-spirv"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"cubecl-common 0.4.0",
"cubecl-core",
"cubecl-opt",
"cubecl-runtime 0.4.0",
"half",
"hashbrown 0.14.5",
"rspirv",
]
@ -1677,7 +1683,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
dependencies = [
"ash",
"async-channel",
@ -1688,29 +1694,20 @@ dependencies = [
"cubecl-core",
"cubecl-runtime 0.4.0",
"cubecl-spirv",
"derive-new",
"derive-new 0.6.0",
"hashbrown 0.14.5",
"log",
"web-time",
"wgpu",
]
[[package]]
name = "cudarc"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bd4d1eee570c3b2ac64ed114125517dd1e541d88dd28fc259f1de4dba8d60"
dependencies = [
"half",
"libloading",
]
[[package]]
name = "cudarc"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
dependencies = [
"half",
"libloading",
]
@ -1720,7 +1717,7 @@ version = "0.16.0"
dependencies = [
"burn",
"csv",
"reqwest 0.12.8",
"reqwest 0.12.9",
"serde",
]
@ -1732,7 +1729,7 @@ dependencies = [
"burn-jit",
"bytemuck",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"log",
"serde",
]
@ -1752,7 +1749,7 @@ version = "0.16.0"
dependencies = [
"burn",
"bytemuck",
"derive-new",
"derive-new 0.7.0",
"guide",
"log",
"serde",
@ -1764,7 +1761,7 @@ version = "0.16.0"
dependencies = [
"burn",
"bytemuck",
"derive-new",
"derive-new 0.7.0",
"guide",
"log",
"serde",
@ -1777,7 +1774,7 @@ dependencies = [
"burn",
"bytemuck",
"cubecl",
"derive-new",
"derive-new 0.7.0",
"log",
"serde",
]
@ -1874,6 +1871,17 @@ dependencies = [
"syn 2.0.87",
]
[[package]]
name = "derive-new"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "derive_arbitrary"
version = "1.3.2"
@ -1927,6 +1935,27 @@ dependencies = [
"syn 2.0.87",
]
[[package]]
name = "derive_more"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.87",
"unicode-xid",
]
[[package]]
name = "deunicode"
version = "1.6.0"
@ -2224,9 +2253,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]]
name = "fdeflate"
version = "0.3.5"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb"
dependencies = [
"simd-adler32",
]
@ -3191,9 +3220,9 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.9"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b"
checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4"
dependencies = [
"bytes",
"futures-channel",
@ -5669,9 +5698,9 @@ dependencies = [
[[package]]
name = "reqwest"
version = "0.12.8"
version = "0.12.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b"
checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f"
dependencies = [
"base64 0.22.1",
"bytes",
@ -5843,9 +5872,9 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.37"
version = "0.38.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811"
checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a"
dependencies = [
"bitflags 2.6.0",
"errno",
@ -5856,9 +5885,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.15"
version = "0.23.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993"
checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e"
dependencies = [
"log",
"once_cell",
@ -5970,9 +5999,9 @@ dependencies = [
[[package]]
name = "scc"
version = "2.2.2"
version = "2.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2c1f7fc6deb21665a9060dfc7d271be784669295a31babdcd4dd2c79ae8cbfb"
checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8"
dependencies = [
"sdd",
]
@ -6649,7 +6678,7 @@ name = "text-classification"
version = "0.16.0"
dependencies = [
"burn",
"derive-new",
"derive-new 0.7.0",
"serde",
"tokenizers",
]
@ -6659,7 +6688,7 @@ name = "text-generation"
version = "0.16.0"
dependencies = [
"burn",
"derive-new",
"derive-new 0.7.0",
"log",
"serde",
"tokenizers",
@ -6946,7 +6975,7 @@ checksum = "4126466aafe1c518cb5c23979c286903cb1d1ff1bc3b76891254a243a0ed1e15"
dependencies = [
"anyhow",
"clap 4.5.20",
"derive_more",
"derive_more 0.99.18",
"env_logger",
"log",
"rand",


@ -29,7 +29,7 @@ version = "0.16.0"
[workspace.dependencies]
atomic_float = "1"
bytemuck = "1.19.0"
candle-core = { version = "0.6.0" }
candle-core = { version = "0.7" }
clap = { version = "4.5.20", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
@ -53,6 +53,7 @@ js-sys = "0.3.69"
libm = "0.2.9"
log = { default-features = false, version = "0.4.22" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.41.3", features = ["lazy"] }
pretty_assertions = "1.4.1"
@ -117,7 +118,7 @@ bincode = { version = "2.0.0-rc.3", features = [
#
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.6.0", default-features = false }
derive-new = { version = "0.7.0", default-features = false }
blas-src = { version = "0.10.0", default-features = false }
half = { version = "2.4.1", features = [
@ -152,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }


@ -73,9 +73,9 @@ impl AutodiffClient for ChannelClient {
.unwrap()
}
fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
let node_id = root.node.id;
let grads = Gradients::new::<B, D>(root.node, root.primitive);
let grads = Gradients::new::<B>(root.node, root.primitive);
let (callback, receiver) = std::sync::mpsc::channel();
self.sender


@ -20,10 +20,10 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
grad_1.to_data().assert_approx_eq(&expected, 3);
grad_1.to_data().assert_approx_eq(&expected, 5);
let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
grad_2.to_data().assert_approx_eq(&expected, 3);
grad_2.to_data().assert_approx_eq(&expected, 5);
}
#[test]
@ -42,10 +42,10 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
grad_1.to_data().assert_approx_eq(&expected, 3);
grad_1.to_data().assert_approx_eq(&expected, 5);
let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
grad_2.to_data().assert_approx_eq(&expected, 3);
grad_2.to_data().assert_approx_eq(&expected, 5);
let contains_nan = grad_2.contains_nan();
assert_eq!(contains_nan.into_scalar(), false);


@ -46,7 +46,7 @@ mod tests {
x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}


@ -62,7 +62,7 @@ mod tests {
x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}


@ -97,7 +97,7 @@ mod tests {
x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}


@ -124,7 +124,7 @@ mod tests {
x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}


@ -40,21 +40,21 @@ mod tests {
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq(&grad_1_slice_1.to_data(), 3);
.assert_approx_eq(&grad_1_slice_1.to_data(), 5);
grad_1
.slice([1..2])
.to_data()
.assert_approx_eq(&grad_1_slice_2.to_data(), 3);
.assert_approx_eq(&grad_1_slice_2.to_data(), 5);
grad_2
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq(&grad_2_slice_1.to_data(), 3);
.assert_approx_eq(&grad_2_slice_1.to_data(), 5);
grad_2
.slice([1..2])
.to_data()
.assert_approx_eq(&grad_2_slice_2.to_data(), 3);
.assert_approx_eq(&grad_2_slice_2.to_data(), 5);
}
#[test]


@ -265,15 +265,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
}
}
}


@ -883,15 +883,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}


@ -512,15 +512,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}


@ -281,15 +281,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}


@ -694,15 +694,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}


@ -567,15 +567,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}


@ -1408,7 +1408,7 @@ mod tests {
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq_diff(&weight_grad_actual.to_data(), 0.04);
}
}
}


@ -67,11 +67,54 @@ mod transpose;
#[macro_export]
macro_rules! testgen_all {
// Avoid using paste dependency with no parameters
() => {
mod autodiff {
pub use super::*;
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
// Behavior
pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;
$crate::testgen_with_float_param!();
}
};
([$($float:ident),*]) => {
mod autodiff {
pub use super::*;
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;
::paste::paste! {
$(mod [<$float _ty>] {
pub use super::*;
pub type TestBackend = TestBackend2<$float, IntType>;
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;
type FloatType = $float;
$crate::testgen_with_float_param!();
})*
}
}
};
}
#[macro_export]
macro_rules! testgen_with_float_param {
() => {
// Behaviour
burn_autodiff::testgen_ad_broadcast!();
burn_autodiff::testgen_gradients!();
burn_autodiff::testgen_bridge!();


@ -20,9 +20,9 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[82.1126, 99.0832], [82.1126, 99.0832]]);
grad_1.to_data().assert_approx_eq(&expected, 3);
grad_1.to_data().assert_approx_eq_diff(&expected, 0.02);
let expected = TensorData::from([[30.3093, 33.1204], [34.5819, 38.7694]]);
grad_2.to_data().assert_approx_eq(&expected, 3);
grad_2.to_data().assert_approx_eq(&expected, 2);
}
}


@ -21,10 +21,10 @@ mod tests {
grad_1
.to_data()
.assert_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), false);
.assert_approx_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), 3);
grad_2
.to_data()
.assert_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), false);
.assert_approx_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), 3);
}
#[test]
@ -48,13 +48,13 @@ mod tests {
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
grad_1.to_data().assert_approx_eq(
&TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),
false,
3,
);
grad_2.to_data().assert_eq(
grad_2.to_data().assert_approx_eq(
&TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),
false,
3,
);
}
}


@ -34,6 +34,8 @@ mod tests {
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
pub type FloatType = f32;
// test activation
burn_tensor::testgen_gelu!();
burn_tensor::testgen_prelu!();


@ -523,7 +523,7 @@ mod tests {
// Should produce the same tokens.
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
.assert_approx_eq(&output_2.into_data(), 2);
}
#[test]


@ -441,7 +441,7 @@ mod tests {
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
.assert_approx_eq(&output_2.into_data(), 2);
}
#[test]


@ -2,38 +2,43 @@
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "CUDA backend for the Burn framework"
documentation = "https://docs.rs/burn-cuda"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu", "cuda"]
license.workspace = true
name = "burn-cuda"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
documentation = "https://docs.rs/burn-cuda"
version.workspace = true
[features]
default = ["fusion", "burn-jit/default", "cubecl/default"]
fusion = ["burn-fusion", "burn-jit/fusion"]
autotune = ["burn-jit/autotune"]
default = ["fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
fusion = ["burn-fusion", "burn-jit/fusion"]
std = ["burn-jit/std", "cubecl/std"]
[dependencies]
cubecl = { workspace = true, features = ["cuda"] }
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = ["cubecl-cuda"] }
burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
"cubecl-cuda",
] }
cubecl = { workspace = true, features = ["cuda"] }
half = { workspace = true }
bytemuck = { workspace = true }
half = { workspace = true }
log = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true }
[dev-dependencies]
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [
"export_tests",
] }
paste = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]


@ -17,6 +17,7 @@ mod tests {
use burn_jit::JitBackend;
pub type TestRuntime = cubecl::cuda::CudaRuntime;
pub use half::{bf16, f16};
burn_jit::testgen_all!();
burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]);
}


@ -9,7 +9,7 @@ fn speech_command() {
println!("Item: {:?}", item);
println!("Item Length: {:?}", item.audio_samples.len());
println!("Label: {}", item.label.to_string());
println!("Label: {}", item.label);
assert_eq!(test.len(), 4890);
assert_eq!(item.label.to_string(), "Yes");


@ -5,7 +5,7 @@ use crate::{
use hound::WavReader;
use serde::{Deserialize, Serialize};
use strum::EnumCount;
use strum::EnumCount as _;
use strum_macros::{Display, EnumCount, FromRepr};
type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples, SpeechItemRaw>;


@ -367,7 +367,7 @@ mod tests {
let dataset: DataframeDataset<TestData> = DataframeDataset::new(df).unwrap();
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.is_empty(), false);
assert!(!dataset.is_empty());
let item = dataset.get(1).unwrap();
assert_eq!(
@ -439,7 +439,7 @@ mod tests {
let dataset = DataframeDataset::<PartialTestData>::new(df).unwrap();
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.is_empty(), false);
assert!(!dataset.is_empty());
let item = dataset.get(1).unwrap();
assert_eq!(


@ -8,6 +8,7 @@ use hashbrown::HashMap;
///
/// It also contains all scalar values, which can change even for the same graph. They are sorted
/// in the order in which they appear in the graph.
#[allow(clippy::too_many_arguments)]
#[derive(new)]
pub struct Context<'a, H> {
/// The tensor mapping where local tensor id points to the updated tensor description.
@ -20,8 +21,22 @@ pub struct Context<'a, H> {
pub scalar_f16: &'a Vec<f16>,
/// BF16 scalars found in the graph in the order they appeared.
pub scalar_bf16: &'a Vec<bf16>,
/// Int scalars found in the graph in the order they appeared.
pub scalar_ints: &'a Vec<i32>,
/// i64 scalars found in the graph in the order they appeared.
pub scalar_i64: &'a Vec<i64>,
/// i32 scalars found in the graph in the order they appeared.
pub scalar_i32: &'a Vec<i32>,
/// i16 scalars found in the graph in the order they appeared.
pub scalar_i16: &'a Vec<i16>,
/// i8 scalars found in the graph in the order they appeared.
pub scalar_i8: &'a Vec<i8>,
/// u64 scalars found in the graph in the order they appeared.
pub scalar_u64: &'a Vec<u64>,
/// u32 scalars found in the graph in the order they appeared.
pub scalar_u32: &'a Vec<u32>,
/// u16 scalars found in the graph in the order they appeared.
pub scalar_u16: &'a Vec<u16>,
/// u8 scalars found in the graph in the order they appeared.
pub scalar_u8: &'a Vec<u8>,
}
#[derive(Default)]
@ -34,7 +49,14 @@ pub(crate) struct OperationConverter {
scalar_f32: Vec<f32>,
scalar_f16: Vec<f16>,
scalar_bf16: Vec<bf16>,
scalar_ints: Vec<i32>,
scalar_i64: Vec<i64>,
scalar_i32: Vec<i32>,
scalar_i16: Vec<i16>,
scalar_i8: Vec<i8>,
scalar_u64: Vec<u64>,
scalar_u32: Vec<u32>,
scalar_u16: Vec<u16>,
scalar_u8: Vec<u8>,
}
pub(crate) trait RelativeOps {
@ -63,7 +85,14 @@ impl OperationConverter {
scalar_f32: &self.scalar_f32,
scalar_f16: &self.scalar_f16,
scalar_bf16: &self.scalar_bf16,
scalar_ints: &self.scalar_ints,
scalar_i64: &self.scalar_i64,
scalar_i32: &self.scalar_i32,
scalar_i16: &self.scalar_i16,
scalar_i8: &self.scalar_i8,
scalar_u64: &self.scalar_u64,
scalar_u32: &self.scalar_u32,
scalar_u16: &self.scalar_u16,
scalar_u8: &self.scalar_u8,
}
}
@ -74,7 +103,14 @@ impl OperationConverter {
self.scalar_f32.clear();
self.scalar_f16.clear();
self.scalar_bf16.clear();
self.scalar_ints.clear();
self.scalar_i64.clear();
self.scalar_i32.clear();
self.scalar_i16.clear();
self.scalar_i8.clear();
self.scalar_u64.clear();
self.scalar_u32.clear();
self.scalar_u16.clear();
self.scalar_u8.clear();
}
pub(crate) fn relative_float<E: Element>(&mut self, elem: &E, dtype: &DType) -> E {
@ -90,8 +126,18 @@ impl OperationConverter {
0.elem()
}
pub(crate) fn relative_int<E: Element>(&mut self, elem: &E) -> E {
self.scalar_ints.push(elem.elem());
pub(crate) fn relative_int<E: Element>(&mut self, elem: &E, dtype: &DType) -> E {
match dtype {
DType::I64 => self.scalar_i64.push(elem.elem()),
DType::I32 => self.scalar_i32.push(elem.elem()),
DType::I16 => self.scalar_i16.push(elem.elem()),
DType::I8 => self.scalar_i8.push(elem.elem()),
DType::U64 => self.scalar_u64.push(elem.elem()),
DType::U32 => self.scalar_u32.push(elem.elem()),
DType::U16 => self.scalar_u16.push(elem.elem()),
DType::U8 => self.scalar_u8.push(elem.elem()),
_ => todo!("Unsupported"),
}
// We return 0 so that the id from a scalar operation is the same no matter its scalar
// value.
0.elem()
@ -116,7 +162,7 @@ impl RelativeOps for OperationDescription {
),
OperationDescription::NumericInt(dtype, ops) => OperationDescription::NumericInt(
*dtype,
ops.to_relative(converter, |converter, e| converter.relative_int(e)),
ops.to_relative(converter, |converter, e| converter.relative_int(e, dtype)),
),
OperationDescription::Bool(ops) => {
OperationDescription::Bool(ops.to_relative(converter))
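A minimal usage sketch of the reworked scalar tracking above (the standalone call and values are illustrative; in the diff the method is only reached through the `to_relative` closures, and both the type and its fields are crate-private):

use burn_tensor::DType;

let mut converter = OperationConverter::default();
// The concrete value lands in the matching per-dtype vector, while the
// relative representation gets a placeholder 0 so identical graphs with
// different scalar values still map onto the same fused kernel.
let relative: i16 = converter.relative_int(&7i16, &DType::I16);
assert_eq!(relative, 0);
assert_eq!(converter.scalar_i16, vec![7i16]);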


@ -2,13 +2,13 @@
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Generic backend that can be compiled just-in-time to any shader language target"
documentation = "https://docs.rs/burn-jit"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu"]
license.workspace = true
name = "burn-jit"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit"
documentation = "https://docs.rs/burn-jit"
version.workspace = true
[features]
@ -22,6 +22,7 @@ export_tests = [
"burn-tensor/export_tests",
"burn-ndarray",
"fusion",
"paste",
]
fusion = ["burn-fusion"]
std = ["cubecl/std"]
@ -31,7 +32,8 @@ template = []
burn-common = { path = "../burn-common", version = "0.16.0" }
burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
"cubecl", "repr",
"cubecl",
"repr",
] }
cubecl = { workspace = true, features = ["linalg"] }
@ -56,6 +58,7 @@ hashbrown = { workspace = true }
# When exporting tests
burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, optional = true }
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true }
paste = { workspace = true, optional = true }
serial_test = { workspace = true, optional = true }
[package.metadata.docs.rs]


@ -1,4 +1,5 @@
use cubecl::{
flex32,
prelude::{Float, Int, Numeric},
CubeElement,
};
@ -12,17 +13,26 @@ pub trait FloatElement: JitElement + Float {}
/// The int element type for the jit backend.
pub trait IntElement: JitElement + Int {}
impl JitElement for u64 {}
impl JitElement for u32 {}
impl JitElement for u16 {}
impl JitElement for u8 {}
impl JitElement for i64 {}
impl JitElement for i32 {}
impl JitElement for i16 {}
impl JitElement for i8 {}
impl JitElement for f64 {}
impl JitElement for f32 {}
impl JitElement for flex32 {}
impl JitElement for half::f16 {}
impl JitElement for half::bf16 {}
impl FloatElement for f64 {}
impl FloatElement for f32 {}
impl FloatElement for flex32 {}
impl FloatElement for half::bf16 {}
impl FloatElement for half::f16 {}
impl IntElement for i64 {}
impl IntElement for i32 {}
impl IntElement for i16 {}
impl IntElement for i8 {}
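With the impls above, the narrower types become usable as backend element parameters. A hedged sketch of what that enables (the JitBackend<Runtime, FloatElem, IntElem> parameter shape is assumed from how the test harness instantiates backends; check the crate for the exact signature):

use burn_jit::JitBackend;
use cubecl::cuda::CudaRuntime;
use half::f16;

// Illustrative only: a CUDA JIT backend computing with f16 floats and i16 ints.
type HalfBackend = JitBackend<CudaRuntime, f16, i16>;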


@ -68,15 +68,29 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseOptimization<R> {
Arg::Input(index, precision, _) => match precision {
ElemwisePrecision::F32 => inputs.t_f32.values.get(index as usize),
ElemwisePrecision::F16 => inputs.t_f16.values.get(index as usize),
ElemwisePrecision::BF16 => inputs.t_bf16.values.get(index as usize),
ElemwisePrecision::U64 => inputs.t_u64.values.get(index as usize),
ElemwisePrecision::U32 => inputs.t_u32.values.get(index as usize),
ElemwisePrecision::U16 => inputs.t_u16.values.get(index as usize),
ElemwisePrecision::U8 => inputs.t_u8.values.get(index as usize),
ElemwisePrecision::I64 => inputs.t_i64.values.get(index as usize),
ElemwisePrecision::I32 => inputs.t_i32.values.get(index as usize),
ElemwisePrecision::I16 => inputs.t_i16.values.get(index as usize),
ElemwisePrecision::I8 => inputs.t_i8.values.get(index as usize),
_ => panic!("Invalid value"),
},
Arg::Output(index, precision, _) => match precision {
ElemwisePrecision::F32 => outputs.t_f32.values.get(index as usize),
ElemwisePrecision::F16 => outputs.t_f16.values.get(index as usize),
ElemwisePrecision::BF16 => outputs.t_bf16.values.get(index as usize),
ElemwisePrecision::U64 => outputs.t_u64.values.get(index as usize),
ElemwisePrecision::U32 => outputs.t_u32.values.get(index as usize),
ElemwisePrecision::U16 => outputs.t_u16.values.get(index as usize),
ElemwisePrecision::U8 => outputs.t_u8.values.get(index as usize),
ElemwisePrecision::I64 => outputs.t_i64.values.get(index as usize),
ElemwisePrecision::I32 => outputs.t_i32.values.get(index as usize),
ElemwisePrecision::I16 => outputs.t_i16.values.get(index as usize),
ElemwisePrecision::I8 => outputs.t_i8.values.get(index as usize),
_ => panic!("Invalid value"),
},
_ => panic!("Invalid value"),
@ -142,11 +156,11 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseOptimization<R> {
let mut output = u8::MAX;
for (handle, tensor) in handles_inputs.zip(inputs) {
output = u8::min(vectorization_input(handle, tensor), output);
output = Ord::min(vectorization_input(handle, tensor), output);
}
for tensor in outputs {
output = u8::min(vectorization_output(tensor), output);
output = Ord::min(vectorization_output(tensor), output);
}
output
@ -168,15 +182,29 @@ fn elemwise_fuse(
Arg::Input(index, precision, _) => match comptime![precision] {
ElemwisePrecision::F32 => inputs.t_f32.index(index).len(),
ElemwisePrecision::F16 => inputs.t_f16.index(index).len(),
ElemwisePrecision::BF16 => inputs.t_bf16.index(index).len(),
ElemwisePrecision::U64 => inputs.t_u64.index(index).len(),
ElemwisePrecision::U32 => inputs.t_u32.index(index).len(),
ElemwisePrecision::U16 => inputs.t_u16.index(index).len(),
ElemwisePrecision::U8 => inputs.t_u8.index(index).len(),
ElemwisePrecision::I64 => inputs.t_i64.index(index).len(),
ElemwisePrecision::I32 => inputs.t_i32.index(index).len(),
ElemwisePrecision::I16 => inputs.t_i16.index(index).len(),
ElemwisePrecision::I8 => inputs.t_i8.index(index).len(),
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Output(index, precision, _) => match comptime![precision] {
ElemwisePrecision::F32 => outputs.t_f32.index(index).len(),
ElemwisePrecision::F16 => outputs.t_f16.index(index).len(),
ElemwisePrecision::BF16 => outputs.t_bf16.index(index).len(),
ElemwisePrecision::U64 => outputs.t_u64.index(index).len(),
ElemwisePrecision::U32 => outputs.t_u32.index(index).len(),
ElemwisePrecision::U16 => outputs.t_u16.index(index).len(),
ElemwisePrecision::U8 => outputs.t_u8.index(index).len(),
ElemwisePrecision::I64 => outputs.t_i64.index(index).len(),
ElemwisePrecision::I32 => outputs.t_i32.index(index).len(),
ElemwisePrecision::I16 => outputs.t_i16.index(index).len(),
ElemwisePrecision::I8 => outputs.t_i8.index(index).len(),
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
_ => comptime![panic!("Invalid ref layout.")],


@ -1,337 +0,0 @@
use cubecl::calculate_num_elems_dyn_rank;
use cubecl::prelude::*;
use crate::fusion::strides_dyn_rank;
use crate::fusion::JitFusionHandle;
use crate::kernel::Kernel;
use crate::JitRuntime;
use burn_fusion::stream::Context;
use burn_tensor::repr::TensorDescription;
use burn_tensor::repr::TensorStatus;
use cubecl::client::ComputeClient;
use cubecl::server::Binding;
use cubecl::tune::AutotuneOperation;
use std::marker::PhantomData;
use std::sync::Arc;
use super::tracing::ExecutionInfo;
#[derive(new)]
pub struct FusionKernel<R: JitRuntime> {
id: u64, // Same ID for all different settings.
info: Arc<KernelExpansion>,
settings: KernelSettings,
runtime_info: Vec<OutputRuntimeInfo>,
cube_count: CubeCount<R::Server>,
_runtime: PhantomData<R>,
}
pub trait FusionKernelFactory<R: JitRuntime> {
/// Create a new kernel.
fn create(
&self,
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
stateful: bool, // Should be set to false when running autotune.
) -> FusionKernel<R>;
}
/// An instantiation of a [kernel](Kernel) that can be executed.
#[derive(new)]
pub struct ExecutableKernel<R: JitRuntime> {
kernel: Box<dyn CubeTask<R::Compiler>>,
cube_count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>,
}
/// An instantiation of a [kernel](Kernel) that can be autotuned.
///
/// The main difference with an [executable kernel](ExecutableKernel) is that this kernel can be
/// cloned and executed multiple times to properly collect benchmarks.
///
/// The clone function used is defined in the trait [AutotuneOperation] instead of [Clone].
#[derive(new)]
pub struct AutotunableKernel<R: JitRuntime> {
kernel: Arc<dyn CubeTask<R::Compiler>>,
count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>,
}
// Information related to the output of this kernel.
#[derive(Debug)]
pub enum OutputRuntimeInfo {
Inplace { input_index: usize },
Array { size: usize },
}
impl<R: JitRuntime> ExecutableKernel<R> {
/// Execute the kernel.
pub fn execute(self) {
unsafe {
self.client
.execute_unchecked(self.kernel, self.cube_count, self.bindings)
}
}
}
impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> {
fn execute(self: Box<Self>) {
self.client
.execute(Box::new(self.kernel), self.count, self.bindings)
}
fn clone(&self) -> Box<dyn AutotuneOperation> {
Box::new(Self {
kernel: self.kernel.clone(),
count: self.count.clone(),
bindings: self.bindings.clone(),
client: self.client.clone(),
})
}
}
impl<R: JitRuntime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
fn from(value: ExecutableKernel<R>) -> Self {
Self {
kernel: Arc::new(value.kernel),
count: value.cube_count.clone(),
bindings: value.bindings,
client: value.client,
}
}
}
impl<R: JitRuntime> FusionKernel<R> {
pub fn create<K>(
factory: &K,
running_info: &ExecutionInfo<'_>,
context: &mut Context<'_, JitFusionHandle<R>>,
device: R::Device,
client: ComputeClient<R::Server, R::Channel>,
stateful: bool,
) -> ExecutableKernel<R>
where
K: FusionKernelFactory<R>,
{
let (handles_input, inputs_description_updated, outputs_description_updated) =
process_inputs_outputs(
&running_info.inputs,
&running_info.outputs,
context,
stateful,
);
let fusion_kernel = factory.create(
&handles_input,
&inputs_description_updated,
&outputs_description_updated,
stateful,
);
let rank_input = running_info
.inputs
.first()
.map(|desc| desc.shape.len())
.unwrap_or(1);
let rank_output = running_info
.outputs
.first()
.map(|desc| desc.shape.len())
.unwrap_or(1);
let rank = usize::max(rank_input, rank_output);
let num_tensors = running_info.inputs.len() + running_info.outputs.len();
// The buffer starts with the rank, then each tensor shape and stride.
let info_size = (num_tensors * rank * 2) + 1;
let mut num_handles = num_tensors + 1;
if running_info.scalars.num_f32 > 0 {
num_handles += 1;
}
if running_info.scalars.num_f16 > 0 {
num_handles += 1;
}
if running_info.scalars.num_bf16 > 0 {
num_handles += 1;
}
if running_info.scalars.num_int > 0 {
num_handles += 1;
}
let mut info = Vec::with_capacity(info_size);
let mut bindings = Vec::with_capacity(num_handles);
let mut output_register = Vec::with_capacity(outputs_description_updated.len());
// We register the info and handles for the inputs.
for (handle, tensor) in handles_input.iter().zip(inputs_description_updated.iter()) {
register_info_tensor(&mut info, tensor, handle);
bindings.push(handle.handle.clone().binding());
}
// We register the info and handles for the outputs.
for (tensor, output_info) in outputs_description_updated
.iter()
.zip(fusion_kernel.runtime_info.iter())
{
match output_info {
// Use the input inplace for this output.
OutputRuntimeInfo::Inplace { input_index } => {
let input = handles_input.get(*input_index).unwrap();
let handle_fusion = JitFusionHandle {
client: client.clone(),
device: device.clone(),
strides: strides_dyn_rank(&tensor.shape),
handle: input.handle.clone(),
};
output_register.push((tensor.id, handle_fusion));
}
// Create a new buffer for this output.
OutputRuntimeInfo::Array { size } => {
let handle_fusion = JitFusionHandle {
client: client.clone(),
device: device.clone(),
strides: strides_dyn_rank(&tensor.shape),
handle: client.empty(*size),
};
register_info_tensor(&mut info, tensor, &handle_fusion);
bindings.push(handle_fusion.handle.clone().binding());
output_register.push((tensor.id, handle_fusion));
}
};
}
// [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len]
if R::require_array_lengths() {
for input in inputs_description_updated.iter() {
let len = calculate_num_elems_dyn_rank(&input.shape);
info.push(len as u32);
}
for output in outputs_description_updated.iter() {
let len = calculate_num_elems_dyn_rank(&output.shape);
info.push(len as u32);
}
}
// Create the info buffer.
bindings.push(client.create(bytemuck::cast_slice(&info)).binding());
// Finally we finish with the named bindings.
if running_info.scalars.num_f32 > 0 {
let bytes = bytemuck::cast_slice(&context.scalar_f32[0..running_info.scalars.num_f32]);
bindings.push(client.create(bytes).binding());
}
if running_info.scalars.num_f16 > 0 {
let bytes = bytemuck::cast_slice(&context.scalar_f16[0..running_info.scalars.num_f16]);
bindings.push(client.create(bytes).binding());
}
if running_info.scalars.num_bf16 > 0 {
let bytes =
bytemuck::cast_slice(&context.scalar_bf16[0..running_info.scalars.num_bf16]);
bindings.push(client.create(bytes).binding());
}
if running_info.scalars.num_int > 0 {
bindings.push(
client
.create(bytemuck::cast_slice(
&context.scalar_ints[0..running_info.scalars.num_int],
))
.binding(),
);
}
// We have to register the output handles to the context.
for (id, handle) in output_register {
context.handles.register_handle(id, handle);
}
let cube_count = fusion_kernel.cube_count.clone();
ExecutableKernel::new(
Box::new(KernelTask::<R::Compiler, _>::new(fusion_kernel)),
cube_count,
bindings,
client,
)
}
}
impl<R: JitRuntime> Kernel for FusionKernel<R> {
fn define(&self) -> KernelDefinition {
log::info!("Compiling ... {:?}", self.id());
KernelIntegrator::new(self.info.as_ref().clone()).integrate(self.settings.clone())
}
fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self>().info((self.settings.clone(), self.id))
}
}
fn register_info_tensor<R: JitRuntime>(
info: &mut Vec<u32>,
tensor: &TensorDescription,
handle: &JitFusionHandle<R>,
) {
if info.is_empty() {
info.push(handle.strides.len() as u32);
}
for s in handle.strides.iter() {
info.push(*s as u32);
}
for s in tensor.shape.iter() {
info.push(*s as u32);
}
}
fn process_inputs_outputs<'a, R>(
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
context: &'a mut Context<'_, JitFusionHandle<R>>,
stateful: bool,
) -> (
Vec<JitFusionHandle<R>>,
Vec<&'a TensorDescription>,
Vec<&'a TensorDescription>,
)
where
R: JitRuntime,
{
let mut inputs_description_updated = Vec::with_capacity(inputs.len());
let mut outputs_description_updated = Vec::with_capacity(outputs.len());
let mut handles_input = Vec::new();
for tensor in inputs.iter() {
let status = if stateful {
&tensor.status // Important to take the status of the relative graph and not
// the global graph, since the status of the global graph
// might be of a later operation on the same tensor id.
} else {
&TensorStatus::ReadOnly
};
let tensor = context.tensors.get(&tensor.id).unwrap();
let handle = context.handles.get_handle(&tensor.id, status);
handles_input.push(handle);
inputs_description_updated.push(tensor);
}
for tensor in outputs.iter() {
let tensor = context.tensors.get(&tensor.id).unwrap();
outputs_description_updated.push(tensor);
}
(
handles_input,
inputs_description_updated,
outputs_description_updated,
)
}


@ -31,6 +31,24 @@ pub fn read<C: CubePrimitive>(
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::BF16 => {
let tensor = inputs.t_bf16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U64 => {
let tensor = inputs.t_u64.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U32 => {
let tensor = inputs.t_u32.index(pos);
let offset = match layout {
@ -40,6 +58,33 @@ pub fn read<C: CubePrimitive>(
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U16 => {
let tensor = inputs.t_u16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U8 => {
let tensor = inputs.t_u8.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I64 => {
let tensor = inputs.t_i64.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I32 => {
let tensor = inputs.t_i32.index(pos);
let offset = match layout {
@ -49,6 +94,24 @@ pub fn read<C: CubePrimitive>(
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I16 => {
let tensor = inputs.t_i16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I8 => {
let tensor = inputs.t_i8.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Output(pos, precision, layout) => match comptime![precision] {
@ -70,6 +133,24 @@ pub fn read<C: CubePrimitive>(
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::BF16 => {
let tensor = outputs.t_bf16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U64 => {
let tensor = outputs.t_u64.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U32 => {
let tensor = outputs.t_u32.index(pos);
let offset = match layout {
@ -79,6 +160,33 @@ pub fn read<C: CubePrimitive>(
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U16 => {
let tensor = outputs.t_u16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::U8 => {
let tensor = outputs.t_u8.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I64 => {
let tensor = outputs.t_i64.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I32 => {
let tensor = outputs.t_i32.index(pos);
let offset = match layout {
@ -88,22 +196,52 @@ pub fn read<C: CubePrimitive>(
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I16 => {
let tensor = outputs.t_i16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
ElemwisePrecision::I8 => {
let tensor = outputs.t_i8.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
Line::cast_from(tensor[offset])
}
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Local(pos, precision) => match comptime![precision] {
ElemwisePrecision::F32 => Line::cast_from(locals.l_f32.find(pos)),
ElemwisePrecision::F16 => Line::cast_from(locals.l_f16.find(pos)),
ElemwisePrecision::BF16 => Line::cast_from(locals.l_bf16.find(pos)),
ElemwisePrecision::U64 => Line::cast_from(locals.l_u64.find(pos)),
ElemwisePrecision::U32 => Line::cast_from(locals.l_u32.find(pos)),
ElemwisePrecision::U16 => Line::cast_from(locals.l_u16.find(pos)),
ElemwisePrecision::U8 => Line::cast_from(locals.l_u8.find(pos)),
ElemwisePrecision::I64 => Line::cast_from(locals.l_i64.find(pos)),
ElemwisePrecision::I32 => Line::cast_from(locals.l_i32.find(pos)),
ElemwisePrecision::I16 => Line::cast_from(locals.l_i16.find(pos)),
ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)),
ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)),
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Scalar(pos, precision) => match comptime![precision] {
ElemwisePrecision::F32 => Line::cast_from(*inputs.s_f32.index(pos)),
ElemwisePrecision::F16 => Line::cast_from(*inputs.s_f16.index(pos)),
ElemwisePrecision::BF16 => Line::cast_from(*inputs.s_bf16.index(pos)),
ElemwisePrecision::U64 => Line::cast_from(*inputs.s_u64.index(pos)),
ElemwisePrecision::U32 => Line::cast_from(*inputs.s_u32.index(pos)),
ElemwisePrecision::U16 => Line::cast_from(*inputs.s_u16.index(pos)),
ElemwisePrecision::U8 => Line::cast_from(*inputs.s_u8.index(pos)),
ElemwisePrecision::I64 => Line::cast_from(*inputs.s_i64.index(pos)),
ElemwisePrecision::I32 => Line::cast_from(*inputs.s_i32.index(pos)),
ElemwisePrecision::BF16 => comptime![panic!("Can't write into inputs or scalars")],
ElemwisePrecision::I16 => Line::cast_from(*inputs.s_i16.index(pos)),
ElemwisePrecision::I8 => Line::cast_from(*inputs.s_i8.index(pos)),
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Literal(val, _precision) => Line::cast_from(val.runtime()),
@ -143,6 +281,26 @@ pub fn write<C: CubePrimitive>(
let tensor = outputs.t_f16.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::BF16 => {
let tensor = outputs.t_bf16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_bf16.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::U64 => {
let tensor = outputs.t_u64.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_u64.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::U32 => {
let tensor = outputs.t_u32.index(pos);
let offset = match layout {
@ -153,6 +311,36 @@ pub fn write<C: CubePrimitive>(
let tensor = outputs.t_u32.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::U16 => {
let tensor = outputs.t_u16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_u16.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::U8 => {
let tensor = outputs.t_u8.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_u8.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::I64 => {
let tensor = outputs.t_i64.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_i64.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::I32 => {
let tensor = outputs.t_i32.index(pos);
let offset = match layout {
@ -163,15 +351,41 @@ pub fn write<C: CubePrimitive>(
let tensor = outputs.t_i32.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::I16 => {
let tensor = outputs.t_i16.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_i16.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
ElemwisePrecision::I8 => {
let tensor = outputs.t_i8.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
};
let tensor = outputs.t_i8.index_mut(pos);
tensor[offset] = Line::cast_from(value);
}
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Local(pos, precision) => match comptime![precision] {
ElemwisePrecision::F32 => locals.l_f32.insert(pos, Line::cast_from(value)),
ElemwisePrecision::F16 => locals.l_f16.insert(pos, Line::cast_from(value)),
ElemwisePrecision::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)),
ElemwisePrecision::U64 => locals.l_u64.insert(pos, Line::cast_from(value)),
ElemwisePrecision::U32 => locals.l_u32.insert(pos, Line::cast_from(value)),
ElemwisePrecision::U16 => locals.l_u16.insert(pos, Line::cast_from(value)),
ElemwisePrecision::U8 => locals.l_u8.insert(pos, Line::cast_from(value)),
ElemwisePrecision::I64 => locals.l_i64.insert(pos, Line::cast_from(value)),
ElemwisePrecision::I32 => locals.l_i32.insert(pos, Line::cast_from(value)),
ElemwisePrecision::I16 => locals.l_i16.insert(pos, Line::cast_from(value)),
ElemwisePrecision::I8 => locals.l_i8.insert(pos, Line::cast_from(value)),
ElemwisePrecision::Bool => locals.l_bool.insert(pos, Line::cast_from(value)),
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
_ => comptime![panic!("Can't write into inputs and scalars")],
}
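Each output arm above repeats the same three-way LayoutInfo decision: when the output already shares the reference layout (SameAsRef or IsRef), the reference position is used directly as the offset; otherwise the position is re-projected through get_offset. A minimal stand-alone sketch of that decision, assuming a simplified local enum and a caller-supplied projection rather than the real cubecl types:
enum LayoutInfo { SameAsRef, IsRef, Unknown }
// Sketch: pick the write offset for one output, mirroring the arms above.
fn resolve_offset(layout: LayoutInfo, ref_pos: usize, project: impl Fn(usize) -> usize) -> usize {
    match layout {
        // The output layout matches the reference, so the position is the offset.
        LayoutInfo::SameAsRef | LayoutInfo::IsRef => ref_pos,
        // Unknown layout: recompute the offset from shapes and strides.
        LayoutInfo::Unknown => project(ref_pos),
    }
}
fn main() {
    assert_eq!(resolve_offset(LayoutInfo::IsRef, 7, |p| p * 2), 7);
    assert_eq!(resolve_offset(LayoutInfo::Unknown, 7, |p| p * 2), 14);
}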
@ -195,14 +409,42 @@ fn get_offset<C: CubePrimitive>(
let layout = inputs.t_f16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::BF16 => {
let layout = inputs.t_bf16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U64 => {
let layout = inputs.t_u64.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U32 => {
let layout = inputs.t_u32.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U16 => {
let layout = inputs.t_u16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U8 => {
let layout = inputs.t_u8.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I64 => {
let layout = inputs.t_i64.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I32 => {
let layout = inputs.t_i32.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I16 => {
let layout = inputs.t_i16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I8 => {
let layout = inputs.t_i8.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
Arg::Output(index, precision, _) => match comptime![precision] {
@ -214,14 +456,42 @@ fn get_offset<C: CubePrimitive>(
let layout = outputs.t_f16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::BF16 => {
let layout = outputs.t_bf16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U64 => {
let layout = outputs.t_u64.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U32 => {
let layout = outputs.t_u32.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U16 => {
let layout = outputs.t_u16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::U8 => {
let layout = outputs.t_u8.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I64 => {
let layout = outputs.t_i64.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I32 => {
let layout = outputs.t_i32.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I16 => {
let layout = outputs.t_i16.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
ElemwisePrecision::I8 => {
let layout = outputs.t_i8.index(index);
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
}
_ => comptime![panic!("Unsupported precision {precision:?}")],
},
_ => comptime![panic!("Invalid ref layout.")],
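The arms above differ only in which tensor sequence supplies the layout; the actual re-projection is done by index_offset_with_layout. A host-side sketch of what such a projection does, assuming plain usize shapes and strides and ignoring cubecl's vectorized Line handling:
// Sketch (assumption: scalar indexing, no line/vectorization factor).
// Recover the coordinates of `ref_pos` from the reference layout, then
// re-apply them with the target tensor's strides.
fn project_offset(
    target_strides: &[usize],
    ref_strides: &[usize],
    ref_shape: &[usize],
    ref_pos: usize,
) -> usize {
    let mut offset = 0;
    for dim in 0..ref_shape.len() {
        let coord = (ref_pos / ref_strides[dim]) % ref_shape[dim];
        offset += coord * target_strides[dim];
    }
    offset
}
fn main() {
    // A 2x3 row-major reference position mapped onto a column-major target.
    assert_eq!(project_offset(&[1, 2], &[3, 1], &[2, 3], 4), 3);
}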

View File

@ -79,13 +79,25 @@ pub struct GlobalArgs {
pub t_f32: Sequence<Tensor<Line<f32>>>,
pub t_f16: Sequence<Tensor<Line<f16>>>,
pub t_bf16: Sequence<Tensor<Line<bf16>>>,
pub t_i64: Sequence<Tensor<Line<i64>>>,
pub t_i32: Sequence<Tensor<Line<i32>>>,
pub t_i16: Sequence<Tensor<Line<i16>>>,
pub t_i8: Sequence<Tensor<Line<i8>>>,
pub t_u64: Sequence<Tensor<Line<u64>>>,
pub t_u32: Sequence<Tensor<Line<u32>>>,
pub t_u16: Sequence<Tensor<Line<u16>>>,
pub t_u8: Sequence<Tensor<Line<u8>>>,
pub s_f32: Sequence<f32>,
pub s_f16: Sequence<f16>,
pub s_bf16: Sequence<bf16>,
pub s_i64: Sequence<i64>,
pub s_i32: Sequence<i32>,
pub s_i16: Sequence<i16>,
pub s_i8: Sequence<i8>,
pub s_u64: Sequence<u64>,
pub s_u32: Sequence<u32>,
pub s_u16: Sequence<u16>,
pub s_u8: Sequence<u8>,
}
#[derive(CubeType, Clone)]
@ -95,8 +107,14 @@ pub struct LocalArgs {
pub l_f32: Registry<u32, Line<f32>>,
pub l_f16: Registry<u32, Line<f16>>,
pub l_bf16: Registry<u32, Line<bf16>>,
pub l_i64: Registry<u32, Line<i64>>,
pub l_i32: Registry<u32, Line<i32>>,
pub l_i16: Registry<u32, Line<i16>>,
pub l_i8: Registry<u32, Line<i8>>,
pub l_u64: Registry<u32, Line<u64>>,
pub l_u32: Registry<u32, Line<u32>>,
pub l_u16: Registry<u32, Line<u16>>,
pub l_u8: Registry<u32, Line<u8>>,
pub l_bool: Registry<u32, Line<bool>>,
}
@ -123,9 +141,13 @@ pub enum ElemwisePrecision {
F32,
F16,
BF16,
I64,
I32,
I16,
I8,
U64,
U32,
U16,
U8,
Bool,
}
@ -139,8 +161,18 @@ impl From<Elem> for ElemwisePrecision {
cubecl::ir::FloatKind::F32 => Self::F32,
_ => panic!("Unsupported precision for fusion: {value}"),
},
Elem::Int(cubecl::ir::IntKind::I32) => Self::I32,
Elem::UInt => Self::U32,
Elem::Int(kind) => match kind {
cubecl::ir::IntKind::I64 => Self::I64,
cubecl::ir::IntKind::I32 => Self::I32,
cubecl::ir::IntKind::I16 => Self::I16,
cubecl::ir::IntKind::I8 => Self::I8,
},
Elem::UInt(kind) => match kind {
cubecl::ir::UIntKind::U64 => Self::U64,
cubecl::ir::UIntKind::U32 => Self::U32,
cubecl::ir::UIntKind::U16 => Self::U16,
cubecl::ir::UIntKind::U8 => Self::U8,
},
Elem::Bool => Self::Bool,
_ => panic!("Unsupported precision for fusion: {value}"),
}
@ -153,9 +185,13 @@ impl From<DType> for ElemwisePrecision {
DType::F32 => Self::F32,
DType::F16 => Self::F16,
DType::BF16 => Self::BF16,
DType::I64 => Self::I64,
DType::I32 => Self::I32,
DType::I16 => Self::I16,
DType::I8 => Self::I8,
DType::U64 => Self::U64,
DType::U32 => Self::U32,
DType::U16 => Self::U16,
DType::U8 => Self::U8,
DType::Bool => Self::Bool,
_ => panic!("Unsupported"),
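With the widened enum, the two From impls above give a total mapping from both cubecl Elems and burn DTypes into fusion precisions. A small usage sketch, assuming ElemwisePrecision and burn_tensor::DType are importable from the calling crate:
use burn_tensor::DType;
fn main() {
    // Every supported dtype now resolves to a concrete fusion precision.
    assert!(matches!(ElemwisePrecision::from(DType::U16), ElemwisePrecision::U16));
    assert!(matches!(ElemwisePrecision::from(DType::I8), ElemwisePrecision::I8));
}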

View File

@ -19,8 +19,14 @@ pub fn fuse_on_write<E: CubePrimitive>(
l_f32: Registry::<u32, Line<f32>>::new(),
l_f16: Registry::<u32, Line<f16>>::new(),
l_bf16: Registry::<u32, Line<bf16>>::new(),
l_i64: Registry::<u32, Line<i64>>::new(),
l_i32: Registry::<u32, Line<i32>>::new(),
l_i16: Registry::<u32, Line<i16>>::new(),
l_i8: Registry::<u32, Line<i8>>::new(),
l_u64: Registry::<u32, Line<u64>>::new(),
l_u32: Registry::<u32, Line<u32>>::new(),
l_u16: Registry::<u32, Line<u16>>::new(),
l_u8: Registry::<u32, Line<u8>>::new(),
l_bool: Registry::<u32, Line<bool>>::new(),
};
@ -48,12 +54,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
add::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
add::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
add::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
add::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
add::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
add::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
add::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
add::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
add::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Div(op) => match op.out.precision() {
@ -66,12 +90,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
div::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
div::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
div::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
div::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
div::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
div::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
div::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
div::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
div::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Sub(op) => match op.out.precision() {
@ -84,12 +126,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
sub::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
sub::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
sub::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
sub::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
sub::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
sub::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
sub::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
sub::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
sub::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Mul(op) => match op.out.precision() {
@ -102,12 +162,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
mul::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
mul::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
mul::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
mul::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
mul::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
mul::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
mul::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
mul::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
mul::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Powf(op) => match op.out.precision() {
@ -144,12 +222,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
abs::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
// `abs` is the identity on unsigned types, so these arms just assign.
ElemwisePrecision::U64 => {
assign::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
assign::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
assign::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
assign::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
abs::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
abs::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
abs::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
abs::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Log(op) => match op.out.precision() {
@ -198,16 +294,33 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
assign::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
assign::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
assign::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
assign::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
assign::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
assign::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
assign::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
assign::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
assign::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::Bool => {
assign::<bool>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Exp(op) => match op.out.precision() {
ElemwisePrecision::F32 => {
@ -267,12 +380,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
equal::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
equal::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
equal::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
equal::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
equal::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
equal::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
equal::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
equal::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Greater(op) => match op.lhs.precision() {
@ -285,12 +416,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
greater::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
greater::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
greater::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
greater::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
greater::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
greater::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
greater::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
greater::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
greater::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::GreaterEqual(op) => match op.lhs.precision() {
@ -303,12 +452,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
greater_equal::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
greater_equal::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
greater_equal::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
greater_equal::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
greater_equal::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
greater_equal::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
greater_equal::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
greater_equal::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
greater_equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::Lower(op) => match op.lhs.precision() {
@ -321,12 +488,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
lower::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
lower::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
lower::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
lower::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
lower::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
lower::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
lower::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
lower::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
lower::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::LowerEqual(op) => match op.lhs.precision() {
@ -339,12 +524,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::BF16 => {
lower_equal::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I64 => {
lower_equal::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I32 => {
lower_equal::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I16 => {
lower_equal::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::I8 => {
lower_equal::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U64 => {
lower_equal::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U32 => {
lower_equal::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U16 => {
lower_equal::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
}
ElemwisePrecision::U8 => {
lower_equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
},
ElemwiseOp::ConditionalAssign {
@ -386,6 +589,17 @@ pub fn fuse_on_write<E: CubePrimitive>(
out,
config,
),
ElemwisePrecision::I64 => conditional_assign::<i64>(
inputs,
outputs,
&mut locals,
write_pos,
cond,
lhs,
rhs,
out,
config,
),
ElemwisePrecision::I32 => conditional_assign::<i32>(
inputs,
outputs,
@ -397,6 +611,39 @@ pub fn fuse_on_write<E: CubePrimitive>(
out,
config,
),
ElemwisePrecision::I16 => conditional_assign::<i16>(
inputs,
outputs,
&mut locals,
write_pos,
cond,
lhs,
rhs,
out,
config,
),
ElemwisePrecision::I8 => conditional_assign::<i8>(
inputs,
outputs,
&mut locals,
write_pos,
cond,
lhs,
rhs,
out,
config,
),
ElemwisePrecision::U64 => conditional_assign::<u64>(
inputs,
outputs,
&mut locals,
write_pos,
cond,
lhs,
rhs,
out,
config,
),
ElemwisePrecision::U32 => conditional_assign::<u32>(
inputs,
outputs,
@ -408,6 +655,28 @@ pub fn fuse_on_write<E: CubePrimitive>(
out,
config,
),
ElemwisePrecision::U16 => conditional_assign::<u16>(
inputs,
outputs,
&mut locals,
write_pos,
cond,
lhs,
rhs,
out,
config,
),
ElemwisePrecision::U8 => conditional_assign::<u8>(
inputs,
outputs,
&mut locals,
write_pos,
cond,
lhs,
rhs,
out,
config,
),
_ => comptime![panic!("Unsupported precision {op:?}")],
},
}
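Every arm of fuse_on_write follows the same shape: match the operation, match the output (or lhs) precision at comptime, then instantiate the generic element-wise helper for the matching concrete type. A reduced, non-cubecl sketch of that monomorphizing dispatch, where Precision and apply are stand-ins for ElemwisePrecision and the op helpers:
#[derive(Clone, Copy)]
enum Precision { F32, I64, U8 }
// Stand-in for helpers like `add::<E>` / `mul::<E>`: generic over the element type.
fn apply<E: Default + std::fmt::Debug>(label: &str) {
    println!("{label}: running with {:?}", E::default());
}
fn dispatch(precision: Precision) {
    match precision {
        Precision::F32 => apply::<f32>("f32"),
        Precision::I64 => apply::<i64>("i64"),
        Precision::U8 => apply::<u8>("u8"),
    }
}
fn main() {
    dispatch(Precision::U8); // prints "u8: running with 0"
}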

View File

@ -333,6 +333,18 @@ impl FuseOnWriteTrace {
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
);
for hi in handle_inputs.iter() {
@ -341,8 +353,14 @@ impl FuseOnWriteTrace {
ElemwisePrecision::F32 => inputs.t_f32.push(arg),
ElemwisePrecision::F16 => inputs.t_f16.push(arg),
ElemwisePrecision::BF16 => inputs.t_bf16.push(arg),
ElemwisePrecision::I64 => inputs.t_i64.push(arg),
ElemwisePrecision::I32 => inputs.t_i32.push(arg),
ElemwisePrecision::I16 => inputs.t_i16.push(arg),
ElemwisePrecision::I8 => inputs.t_i8.push(arg),
ElemwisePrecision::U64 => inputs.t_u64.push(arg),
ElemwisePrecision::U32 => inputs.t_u32.push(arg),
ElemwisePrecision::U16 => inputs.t_u16.push(arg),
ElemwisePrecision::U8 => inputs.t_u8.push(arg),
_ => panic!("Unsupported input precision {:?}", hi.precision),
};
}
@ -356,10 +374,30 @@ impl FuseOnWriteTrace {
ElemwisePrecision::F16 => {
inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i]))
}
ElemwisePrecision::I32 => {
inputs.s_i32.push(ScalarArg::new(context.scalar_ints[i]))
ElemwisePrecision::BF16 => {
inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i]))
}
_ => todo!(),
ElemwisePrecision::I64 => {
inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i]))
}
ElemwisePrecision::I32 => {
inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i]))
}
ElemwisePrecision::I16 => {
inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i]))
}
ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])),
ElemwisePrecision::U64 => {
inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i]))
}
ElemwisePrecision::U32 => {
inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i]))
}
ElemwisePrecision::U16 => {
inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i]))
}
ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])),
ElemwisePrecision::Bool => todo!(),
}
}
}
@ -383,6 +421,18 @@ impl FuseOnWriteTrace {
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
SequenceArg::new(),
);
for item in handle_outputs.iter() {
match item {
@ -392,8 +442,15 @@ impl FuseOnWriteTrace {
} => match precision {
ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)),
ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)),
_ => todo!(),
},
HandleOutput::Owned {
@ -406,11 +463,17 @@ impl FuseOnWriteTrace {
match precision {
ElemwisePrecision::F32 => outputs.t_f32.push(arg),
ElemwisePrecision::F16 => outputs.t_f16.push(arg),
ElemwisePrecision::BF16 => outputs.t_bf16.push(arg),
ElemwisePrecision::I64 => outputs.t_i64.push(arg),
ElemwisePrecision::I32 => outputs.t_i32.push(arg),
ElemwisePrecision::I16 => outputs.t_i16.push(arg),
ElemwisePrecision::I8 => outputs.t_i8.push(arg),
ElemwisePrecision::U64 => outputs.t_u64.push(arg),
ElemwisePrecision::U32 => outputs.t_u32.push(arg),
ElemwisePrecision::U16 => outputs.t_u16.push(arg),
ElemwisePrecision::U8 => outputs.t_u8.push(arg),
// Bools are encoded as u32.
ElemwisePrecision::Bool => outputs.t_u32.push(arg),
_ => todo!(),
};
}
}
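Note the Bool arm at the end: boolean outputs do not get their own tensor sequence and are routed into the u32 sequence instead, which is consistent with the comparison kernels below allocating their output buffers with size_of::<u32>() rather than size_of::<E>(). A small sketch of the sizing assumption:
fn main() {
    // Assumption sketch: each bool element occupies a full u32 lane on the device,
    // so buffer sizes are computed against u32, not the logical element type.
    let num_elems = 1024usize;
    assert_eq!(num_elems * core::mem::size_of::<u32>(), 4096);
}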

View File

@ -127,7 +127,7 @@ pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
let vectorization_factor_rhs =
tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1);
let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs);
let mut shape_out = vec![0; ndims];
lhs.shape

View File

@ -119,7 +119,7 @@ pub(crate) fn launch_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(
let vectorization_factor_rhs =
tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1);
let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs);
let mut shape_out = vec![0; ndims];
lhs.shape
@ -169,7 +169,7 @@ pub(crate) fn launch_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(
JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides)
} else {
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
@ -225,7 +225,7 @@ pub(crate) fn launch_scalar_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>
tensor.strides,
)
} else {
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<u32>());
let output = JitTensor::new(
tensor.client.clone(),
buffer,

View File

@ -60,7 +60,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
options.groups,
out_h,
out_w,
&input.device,
&input.client,
) {
panic!(
"Requirements for implicit GEMM not met:
@ -84,7 +84,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
let slice_size = pad_kh * pad_kw * pad_in_channels;
let (cmma_m, cmma_n, cmma_k) =
find_cmma_size::<R, f16, F>(&input.device, gemm_m, gemm_k, gemm_n).unwrap();
find_cmma_size::<R, f16, F>(&input.client, gemm_m, gemm_k, gemm_n).unwrap();
let cube_dim_x = 128;
let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2);
@ -187,7 +187,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 {
let channels = channels as u8;
let elems_per_thread = elems_per_thread as u8;
let smaller = u8::min(channels, elems_per_thread);
let smaller = Ord::min(channels, elems_per_thread);
(1..=smaller)
.rev()
.filter(|it| supported_vecs.contains(it))
@ -663,7 +663,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
groups: usize,
out_h: usize,
out_w: usize,
device: &R::Device,
client: &ComputeClient<R::Server, R::Channel>,
) -> bool {
let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_size[0], kernel_size[1]);
let batch_size = padded_batch_size(batch_size, out_h, out_w);
@ -673,7 +673,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
let gemm_n = out_channels;
let gemm_k = in_channels * kernel_h * kernel_w;
let size = find_cmma_size::<R, f16, E>(device, gemm_m as u32, gemm_k as u32, gemm_n as u32);
let size = find_cmma_size::<R, f16, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);
if let Some((cmma_m, cmma_k, cmma_n)) = size {
let warps_per_cube = 8;
@ -716,12 +716,12 @@ fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize {
}
fn find_cmma_size<R: JitRuntime, F: Float, FAcc: Float>(
device: &R::JitDevice,
client: &ComputeClient<R::Server, R::Channel>,
gemm_m: u32,
gemm_k: u32,
gemm_n: u32,
) -> Option<(u32, u32, u32)> {
supported_cmma_sizes::<R, F, FAcc>(device)
supported_cmma_sizes::<R, F, FAcc>(client)
.into_iter()
.find(|(m, k, n)| {
gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0
@ -730,7 +730,7 @@ fn find_cmma_size<R: JitRuntime, F: Float, FAcc: Float>(
}
fn supported_cmma_sizes<R: JitRuntime, F: Float, FAcc: Float>(
device: &R::JitDevice,
client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<(u8, u8, u8)> {
let requested_sizes = [(16, 16, 16), (32, 16, 8), (8, 16, 32)];
@ -738,9 +738,7 @@ fn supported_cmma_sizes<R: JitRuntime, F: Float, FAcc: Float>(
.iter()
.copied()
.filter(|(m, k, n)| {
R::client(device)
.properties()
.feature_enabled(Feature::Cmma {
client.properties().feature_enabled(Feature::Cmma {
a: F::as_elem(),
b: F::as_elem(),
c: FAcc::as_elem(),
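Threading the ComputeClient through can_do_implicit_gemm, find_cmma_size, and supported_cmma_sizes lets callers that already hold a client (such as the autotuner below) query CMMA support without re-resolving it from the device. A hedged sketch of the size-filtering pattern, with a closure standing in for client.properties().feature_enabled(Feature::Cmma { .. }):
fn main() {
    // Hypothetical stand-in for the client's feature query; only 16x16x16 is "supported".
    let feature_enabled = |(m, k, n): (u8, u8, u8)| (m, k, n) == (16, 16, 16);
    let requested_sizes = [(16, 16, 16), (32, 16, 8), (8, 16, 32)];
    let supported: Vec<_> = requested_sizes
        .iter()
        .copied()
        .filter(|&size| feature_enabled(size))
        .collect();
    assert_eq!(supported, vec![(16, 16, 16)]);
}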

View File

@ -3,6 +3,7 @@ use cubecl::{
ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -43,14 +44,14 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_shape_0 = scope.create_local(Elem::UInt);
let input_shape_1 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let input_stride_0 = scope.create_local(u32::as_elem());
let input_stride_1 = scope.create_local(u32::as_elem());
let input_stride_2 = scope.create_local(u32::as_elem());
let input_stride_3 = scope.create_local(u32::as_elem());
let input_shape_0 = scope.create_local(u32::as_elem());
let input_shape_1 = scope.create_local(u32::as_elem());
let input_shape_2 = scope.create_local(u32::as_elem());
let input_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, input_stride_0 = stride(input, 0u32));
cpa!(scope, input_stride_1 = stride(input, 1u32));
cpa!(scope, input_stride_2 = stride(input, 2u32));
@ -60,14 +61,14 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, input_shape_2 = shape(input, 2u32));
cpa!(scope, input_shape_3 = shape(input, 3u32));
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, output_stride_0 = stride(output, 0u32));
cpa!(scope, output_stride_1 = stride(output, 1u32));
cpa!(scope, output_stride_2 = stride(output, 2u32));
@ -77,14 +78,14 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let weight_stride_0 = scope.create_local(Elem::UInt);
let weight_stride_1 = scope.create_local(Elem::UInt);
let weight_stride_2 = scope.create_local(Elem::UInt);
let weight_stride_3 = scope.create_local(Elem::UInt);
let in_channels = scope.create_local(Elem::UInt);
let weight_shape_1 = scope.create_local(Elem::UInt);
let kernel_size_0 = scope.create_local(Elem::UInt);
let kernel_size_1 = scope.create_local(Elem::UInt);
let weight_stride_0 = scope.create_local(u32::as_elem());
let weight_stride_1 = scope.create_local(u32::as_elem());
let weight_stride_2 = scope.create_local(u32::as_elem());
let weight_stride_3 = scope.create_local(u32::as_elem());
let in_channels = scope.create_local(u32::as_elem());
let weight_shape_1 = scope.create_local(u32::as_elem());
let kernel_size_0 = scope.create_local(u32::as_elem());
let kernel_size_1 = scope.create_local(u32::as_elem());
cpa!(scope, weight_stride_0 = stride(weight, 0u32));
cpa!(scope, weight_stride_1 = stride(weight, 1u32));
cpa!(scope, weight_stride_2 = stride(weight, 2u32));
@ -94,31 +95,31 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(u32::as_elem()));
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
cpa!(scope, stride_0_i = cast(conv_stride_0));
cpa!(scope, stride_1_i = cast(conv_stride_1));
let oc_out = scope.create_local(Elem::UInt);
let oc = scope.create_local(Elem::UInt);
let oc_out = scope.create_local(u32::as_elem());
let oc = scope.create_local(u32::as_elem());
let b = scope.create_local(Elem::UInt);
let oh = scope.create_local(Elem::UInt);
let ow = scope.create_local(Elem::UInt);
let k = scope.create_local(Elem::UInt);
let g = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let oh = scope.create_local(u32::as_elem());
let ow = scope.create_local(u32::as_elem());
let k = scope.create_local(u32::as_elem());
let g = scope.create_local(u32::as_elem());
let ic_start = scope.create_local(Elem::UInt);
let ic_end = scope.create_local(Elem::UInt);
let ic_tmp = scope.create_local(Elem::UInt);
let ic_start = scope.create_local(u32::as_elem());
let ic_end = scope.create_local(u32::as_elem());
let ic_tmp = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -141,20 +142,20 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, ic_start = g * ic_tmp);
cpa!(scope, ic_end = ic_start + ic_tmp);
let tmp_u = scope.create_local(Elem::UInt);
let tmp_u = scope.create_local(u32::as_elem());
let tmp_i = scope.create_local(Elem::Int(IntKind::I32));
let zero_i = scope.zero(Elem::Int(IntKind::I32));
let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32));
let kms_u = scope.create_local(Elem::UInt);
let kms_u = scope.create_local(u32::as_elem());
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ih_start = scope.create_local(Elem::UInt);
let iw_start = scope.create_local(Elem::UInt);
let ih_end = scope.create_local(Elem::UInt);
let iw_end = scope.create_local(Elem::UInt);
let ih_start = scope.create_local(u32::as_elem());
let iw_start = scope.create_local(u32::as_elem());
let ih_end = scope.create_local(u32::as_elem());
let iw_end = scope.create_local(u32::as_elem());
cpa!(scope, kms_u = kernel_size_0 * dilation_0);
cpa!(scope, kms_0 = cast(kms_u));
@ -188,17 +189,17 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, tmp_u = cast(tmp_i));
cpa!(scope, iw_end = min(tmp_u, input_shape_3));
let index_input = scope.create_local(Elem::UInt);
let index_weight = scope.create_local(Elem::UInt);
let index_input = scope.create_local(u32::as_elem());
let index_weight = scope.create_local(u32::as_elem());
let index_input_b = scope.create_local(Elem::UInt);
let index_input_ic = scope.create_local(Elem::UInt);
let index_input_ih = scope.create_local(Elem::UInt);
let index_input_iw = scope.create_local(Elem::UInt);
let index_weight_ic = scope.create_local(Elem::UInt);
let index_weight_oc = scope.create_local(Elem::UInt);
let index_weight_kh = scope.create_local(Elem::UInt);
let index_weight_kw = scope.create_local(Elem::UInt);
let index_input_b = scope.create_local(u32::as_elem());
let index_input_ic = scope.create_local(u32::as_elem());
let index_input_ih = scope.create_local(u32::as_elem());
let index_input_iw = scope.create_local(u32::as_elem());
let index_weight_ic = scope.create_local(u32::as_elem());
let index_weight_oc = scope.create_local(u32::as_elem());
let index_weight_kh = scope.create_local(u32::as_elem());
let index_weight_kw = scope.create_local(u32::as_elem());
cpa!(scope, index_input_b = b * input_stride_0);
cpa!(scope, index_weight_oc = oc * weight_stride_1);
@ -208,15 +209,15 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
let sum = scope.create_local(output.item);
cpa!(scope, sum = bias[oc_out]);
let kh = scope.create_local(Elem::UInt);
let kw = scope.create_local(Elem::UInt);
let numerator_h_base = scope.create_local(Elem::UInt);
let numerator_h = scope.create_local(Elem::UInt);
let numerator_w_base = scope.create_local(Elem::UInt);
let numerator_w = scope.create_local(Elem::UInt);
let numerator_tmp = scope.create_local(Elem::UInt);
let numerator_mod = scope.create_local(Elem::UInt);
let zero = scope.zero(Elem::UInt);
let kh = scope.create_local(u32::as_elem());
let kw = scope.create_local(u32::as_elem());
let numerator_h_base = scope.create_local(u32::as_elem());
let numerator_h = scope.create_local(u32::as_elem());
let numerator_w_base = scope.create_local(u32::as_elem());
let numerator_w = scope.create_local(u32::as_elem());
let numerator_tmp = scope.create_local(u32::as_elem());
let numerator_mod = scope.create_local(u32::as_elem());
let zero = scope.zero(u32::as_elem());
let divisible = scope.create_local(Elem::Bool);
let not_neg = scope.create_local(Elem::Bool);
let cond = scope.create_local(Elem::Bool);
@ -324,7 +325,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: 7,
};
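Because Elem::UInt now carries a UIntKind (see the From<Elem> change earlier in this commit), the legacy cpa kernels stop naming the IR element directly and instead ask the concrete Rust type for it via as_elem(). A simplified mirror of that relationship, using hypothetical local types rather than the real cubecl IR:
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum UIntKind { U8, U16, U32, U64 }
#[derive(Debug, PartialEq)]
enum Elem { UInt(UIntKind) }
trait AsElem { fn as_elem() -> Elem; }
impl AsElem for u32 {
    fn as_elem() -> Elem { Elem::UInt(UIntKind::U32) }
}
fn main() {
    // `scope.create_local(u32::as_elem())` replaces `scope.create_local(Elem::UInt)`.
    assert_eq!(u32::as_elem(), Elem::UInt(UIntKind::U32));
}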

View File

@ -9,10 +9,7 @@ use cubecl::{
use crate::{
kernel::{
conv::{
batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col,
conv2d_implicit_gemm,
},
conv::{batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col},
prng::random_uniform,
},
tensor::JitTensor,
@ -42,7 +39,7 @@ pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
}
#[tune(
operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm),
operations(conv2d_direct, conv2d_im2col),
create_key = create_key,
should_run = should_run
)]
@ -111,7 +108,7 @@ fn should_run<R: JitRuntime, F: FloatElement, I: IntElement>(
op.options.groups,
out_h,
out_w,
&op.input.device,
&op.input.client,
),
_ => true,
}
@ -143,5 +140,6 @@ fn create_key<R: JitRuntime, E: FloatElement>(
width,
batch_size,
bias.is_some(),
E::dtype(),
))
}

View File

@ -91,6 +91,7 @@ fn create_key<R: JitRuntime, E: FloatElement>(
width,
batch_size,
bias.is_some(),
E::dtype(),
))
}

View File

@ -1,3 +1,4 @@
use burn_tensor::DType;
use cubecl::AutotuneKey;
use serde::{Deserialize, Serialize};
@ -20,6 +21,7 @@ pub struct Conv2dAutotuneKey {
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
pub dtype: DType,
}
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
@ -42,4 +44,5 @@ pub struct ConvTranspose2dAutotuneKey {
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
pub dtype: DType,
}
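Adding dtype to both autotune keys keeps tuning results measured for one element type from being reused for another. A hedged illustration of why the field matters, using a tuple as a stand-in for the derived key type:
use burn_tensor::DType;
fn main() {
    // Two otherwise identical convolution shapes must not share a cache entry
    // once their element types differ, so the key now carries the dtype.
    let key_f32 = (7usize, 224usize, 224usize, true, DType::F32);
    let key_f16 = (7usize, 224usize, 224usize, true, DType::F16);
    assert_ne!(key_f32, key_f16);
}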

View File

@ -3,6 +3,7 @@ use cubecl::{
ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -43,16 +44,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
let output = self.output;
let idx = Variable::builtin(Builtin::AbsolutePos);
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_stride_4 = scope.create_local(Elem::UInt);
let input_shape_0 = scope.create_local(Elem::UInt);
let input_shape_1 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let input_shape_4 = scope.create_local(Elem::UInt);
let input_stride_0 = scope.create_local(u32::as_elem());
let input_stride_1 = scope.create_local(u32::as_elem());
let input_stride_2 = scope.create_local(u32::as_elem());
let input_stride_3 = scope.create_local(u32::as_elem());
let input_stride_4 = scope.create_local(u32::as_elem());
let input_shape_0 = scope.create_local(u32::as_elem());
let input_shape_1 = scope.create_local(u32::as_elem());
let input_shape_2 = scope.create_local(u32::as_elem());
let input_shape_3 = scope.create_local(u32::as_elem());
let input_shape_4 = scope.create_local(u32::as_elem());
cpa!(scope, input_stride_0 = stride(input, 0u32));
cpa!(scope, input_stride_1 = stride(input, 1u32));
cpa!(scope, input_stride_2 = stride(input, 2u32));
@ -64,16 +65,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, input_shape_3 = shape(input, 3u32));
cpa!(scope, input_shape_4 = shape(input, 4u32));
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_4 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_4 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_stride_4 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
let output_shape_4 = scope.create_local(u32::as_elem());
cpa!(scope, output_stride_0 = stride(output, 0u32));
cpa!(scope, output_stride_1 = stride(output, 1u32));
cpa!(scope, output_stride_2 = stride(output, 2u32));
@ -85,16 +86,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, output_shape_3 = shape(output, 3u32));
cpa!(scope, output_shape_4 = shape(output, 4u32));
let weight_stride_0 = scope.create_local(Elem::UInt);
let weight_stride_1 = scope.create_local(Elem::UInt);
let weight_stride_2 = scope.create_local(Elem::UInt);
let weight_stride_3 = scope.create_local(Elem::UInt);
let weight_stride_4 = scope.create_local(Elem::UInt);
let in_channels = scope.create_local(Elem::UInt);
let weight_shape_1 = scope.create_local(Elem::UInt);
let kernel_size_0 = scope.create_local(Elem::UInt);
let kernel_size_1 = scope.create_local(Elem::UInt);
let kernel_size_2 = scope.create_local(Elem::UInt);
let weight_stride_0 = scope.create_local(u32::as_elem());
let weight_stride_1 = scope.create_local(u32::as_elem());
let weight_stride_2 = scope.create_local(u32::as_elem());
let weight_stride_3 = scope.create_local(u32::as_elem());
let weight_stride_4 = scope.create_local(u32::as_elem());
let in_channels = scope.create_local(u32::as_elem());
let weight_shape_1 = scope.create_local(u32::as_elem());
let kernel_size_0 = scope.create_local(u32::as_elem());
let kernel_size_1 = scope.create_local(u32::as_elem());
let kernel_size_2 = scope.create_local(u32::as_elem());
cpa!(scope, weight_stride_0 = stride(weight, 0u32));
cpa!(scope, weight_stride_1 = stride(weight, 1u32));
cpa!(scope, weight_stride_2 = stride(weight, 2u32));
@ -106,16 +107,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
cpa!(scope, kernel_size_2 = shape(weight, 4u32));
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(Elem::UInt));
let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(Elem::UInt));
let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(Elem::UInt));
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(u32::as_elem()));
let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(u32::as_elem()));
let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(u32::as_elem()));
let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(u32::as_elem()));
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@ -124,19 +125,19 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, stride_1_i = cast(conv_stride_1));
cpa!(scope, stride_2_i = cast(conv_stride_2));
let oc_out = scope.create_local(Elem::UInt);
let oc = scope.create_local(Elem::UInt);
let oc_out = scope.create_local(u32::as_elem());
let oc = scope.create_local(u32::as_elem());
let b = scope.create_local(Elem::UInt);
let od = scope.create_local(Elem::UInt);
let oh = scope.create_local(Elem::UInt);
let ow = scope.create_local(Elem::UInt);
let k = scope.create_local(Elem::UInt);
let g = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let od = scope.create_local(u32::as_elem());
let oh = scope.create_local(u32::as_elem());
let ow = scope.create_local(u32::as_elem());
let k = scope.create_local(u32::as_elem());
let g = scope.create_local(u32::as_elem());
let ic_start = scope.create_local(Elem::UInt);
let ic_end = scope.create_local(Elem::UInt);
let ic_tmp = scope.create_local(Elem::UInt);
let ic_start = scope.create_local(u32::as_elem());
let ic_end = scope.create_local(u32::as_elem());
let ic_tmp = scope.create_local(u32::as_elem());
cpa!(scope, b = idx / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -162,24 +163,24 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, ic_start = g * ic_tmp);
cpa!(scope, ic_end = ic_start + ic_tmp);
let tmp_u = scope.create_local(Elem::UInt);
let tmp_u = scope.create_local(u32::as_elem());
let tmp_i = scope.create_local(Elem::Int(IntKind::I32));
let zero_i = scope.zero(Elem::Int(IntKind::I32));
let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32));
let kms_u = scope.create_local(Elem::UInt);
let kms_u = scope.create_local(u32::as_elem());
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
let kms_2 = scope.create_local(Elem::Int(IntKind::I32));
let id_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let id_start = scope.create_local(Elem::UInt);
let ih_start = scope.create_local(Elem::UInt);
let iw_start = scope.create_local(Elem::UInt);
let id_end = scope.create_local(Elem::UInt);
let ih_end = scope.create_local(Elem::UInt);
let iw_end = scope.create_local(Elem::UInt);
let id_start = scope.create_local(u32::as_elem());
let ih_start = scope.create_local(u32::as_elem());
let iw_start = scope.create_local(u32::as_elem());
let id_end = scope.create_local(u32::as_elem());
let ih_end = scope.create_local(u32::as_elem());
let iw_end = scope.create_local(u32::as_elem());
cpa!(scope, kms_u = kernel_size_0 * dilation_0);
cpa!(scope, kms_0 = cast(kms_u));
@ -228,19 +229,19 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, tmp_u = cast(tmp_i));
cpa!(scope, iw_end = min(tmp_u, input_shape_4));
let index_input = scope.create_local(Elem::UInt);
let index_weight = scope.create_local(Elem::UInt);
let index_input = scope.create_local(u32::as_elem());
let index_weight = scope.create_local(u32::as_elem());
let index_input_b = scope.create_local(Elem::UInt);
let index_input_ic = scope.create_local(Elem::UInt);
let index_input_id = scope.create_local(Elem::UInt);
let index_input_ih = scope.create_local(Elem::UInt);
let index_input_iw = scope.create_local(Elem::UInt);
let index_weight_ic = scope.create_local(Elem::UInt);
let index_weight_oc = scope.create_local(Elem::UInt);
let index_weight_kd = scope.create_local(Elem::UInt);
let index_weight_kh = scope.create_local(Elem::UInt);
let index_weight_kw = scope.create_local(Elem::UInt);
let index_input_b = scope.create_local(u32::as_elem());
let index_input_ic = scope.create_local(u32::as_elem());
let index_input_id = scope.create_local(u32::as_elem());
let index_input_ih = scope.create_local(u32::as_elem());
let index_input_iw = scope.create_local(u32::as_elem());
let index_weight_ic = scope.create_local(u32::as_elem());
let index_weight_oc = scope.create_local(u32::as_elem());
let index_weight_kd = scope.create_local(u32::as_elem());
let index_weight_kh = scope.create_local(u32::as_elem());
let index_weight_kw = scope.create_local(u32::as_elem());
cpa!(scope, index_input_b = b * input_stride_0);
cpa!(scope, index_weight_oc = oc * weight_stride_1);
@ -250,18 +251,18 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
let sum = scope.create_local(output.item);
cpa!(scope, sum = bias[oc_out]);
let kd = scope.create_local(Elem::UInt);
let kh = scope.create_local(Elem::UInt);
let kw = scope.create_local(Elem::UInt);
let numerator_d_base = scope.create_local(Elem::UInt);
let numerator_d = scope.create_local(Elem::UInt);
let numerator_h_base = scope.create_local(Elem::UInt);
let numerator_h = scope.create_local(Elem::UInt);
let numerator_w_base = scope.create_local(Elem::UInt);
let numerator_w = scope.create_local(Elem::UInt);
let numerator_tmp = scope.create_local(Elem::UInt);
let numerator_mod = scope.create_local(Elem::UInt);
let zero = scope.zero(Elem::UInt);
let kd = scope.create_local(u32::as_elem());
let kh = scope.create_local(u32::as_elem());
let kw = scope.create_local(u32::as_elem());
let numerator_d_base = scope.create_local(u32::as_elem());
let numerator_d = scope.create_local(u32::as_elem());
let numerator_h_base = scope.create_local(u32::as_elem());
let numerator_h = scope.create_local(u32::as_elem());
let numerator_w_base = scope.create_local(u32::as_elem());
let numerator_w = scope.create_local(u32::as_elem());
let numerator_tmp = scope.create_local(u32::as_elem());
let numerator_mod = scope.create_local(u32::as_elem());
let zero = scope.zero(u32::as_elem());
let divisible = scope.create_local(Elem::Bool);
let not_neg = scope.create_local(Elem::Bool);
let cond = scope.create_local(Elem::Bool);
@ -392,7 +393,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: 10,
};

View File

@ -5,7 +5,7 @@ use burn_tensor::{
use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch};
use crate::{
kernel::into_contiguous,
kernel::{cast, into_contiguous},
ops::{
numeric::{empty_device, ones_device, zeros_device},
reshape, swap_dims,
@ -426,7 +426,8 @@ fn compute_input_grad<R: JitRuntime, E: FloatElement>(
let [batch_size, in_channels, height, width] = input_shape.dims();
let (kernel_height, kernel_width) = kernel_dims;
let grad_in = zeros_device::<R, E>(
// Force `f32` to enable bitcasting as `u32`
let grad_in = zeros_device::<R, f32>(
client.clone(),
device.clone(),
Shape::new([batch_size, in_channels, height, width]),
@ -466,7 +467,7 @@ fn compute_input_grad<R: JitRuntime, E: FloatElement>(
use_mask,
);
grad_in
cast(grad_in)
}
#[derive(CubeLaunch)]
@ -564,19 +565,19 @@ fn deform_col2img_kernel<F: Float>(
let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];
float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
float_atomic_add(&mut grad_input[gradient_pos], f32::cast_from(value));
}
}
}
}
#[cube]
fn float_atomic_add<F: Float>(ptr: &mut AtomicU32, value: F) {
if value != F::new(0.0) {
fn float_atomic_add(ptr: &mut AtomicU32, value: f32) {
if value != 0.0 {
let mut v = AtomicU32::load(ptr);
loop {
let prev = v;
let v_float = F::bitcast_from(v);
let v_float = f32::bitcast_from(v);
let new = u32::bitcast_from(v_float + value);
v = AtomicU32::compare_and_swap(ptr, v, new);
if prev == v {
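The rewritten float_atomic_add is a classic compare-and-swap loop over raw bits: load the current u32, reinterpret it as f32, add, and try to publish the new bits until no other invocation raced in between. The gradient buffer is forced to f32 above precisely so these bit patterns are meaningful, and it is cast back to E once accumulation is done. A host-side sketch of the same technique with std atomics (illustration only; orderings are kept relaxed for simplicity):
use std::sync::atomic::{AtomicU32, Ordering};
// Sketch: add an f32 into an AtomicU32 that stores the float's bit pattern.
fn float_atomic_add(cell: &AtomicU32, value: f32) {
    if value == 0.0 {
        return; // nothing to add, skip the CAS loop
    }
    let mut current = cell.load(Ordering::Relaxed);
    loop {
        let new = (f32::from_bits(current) + value).to_bits();
        match cell.compare_exchange_weak(current, new, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => break,
            // Another thread updated the cell first; retry with the value it observed.
            Err(observed) => current = observed,
        }
    }
}
fn main() {
    let cell = AtomicU32::new(1.5f32.to_bits());
    float_atomic_add(&cell, 2.25);
    assert_eq!(f32::from_bits(cell.load(Ordering::Relaxed)), 3.75);
}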

View File

@ -5,6 +5,7 @@ use burn_tensor::ElementConversion;
use cubecl::{
cpa,
ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -29,12 +30,12 @@ impl FlipComputeShader {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.create_local(Elem::UInt);
let offset_input = scope.zero(u32::as_elem());
let offset_local = scope.create_local(u32::as_elem());
let stride = scope.create_local(Elem::UInt);
let shape = scope.create_local(Elem::UInt);
let flip = scope.create_local(Elem::UInt);
let stride = scope.create_local(u32::as_elem());
let shape = scope.create_local(u32::as_elem());
let flip = scope.create_local(u32::as_elem());
let flip_bool = scope.create_local(Elem::Bool);
for i in 0..self.rank {
@ -44,7 +45,7 @@ impl FlipComputeShader {
scope,
flip = cast(Variable::new(
VariableKind::GlobalScalar(i as u16),
Item::new(Elem::UInt)
Item::new(u32::as_elem())
))
);
cpa!(scope, flip_bool = flip == 1u32);
@ -89,7 +90,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
visibility: Visibility::Read,
};
let flip_dims = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: self.rank,
};
let output = OutputInfo::Array { item };

View File

@ -1,7 +1,8 @@
use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
use cubecl::{
cpa,
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::CubePrimitive,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -28,12 +29,12 @@ impl RepeatComputeShader {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.zero(Elem::UInt);
let offset_input = scope.zero(u32::as_elem());
let offset_local = scope.zero(u32::as_elem());
let stride_input = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape = scope.create_local(Elem::UInt);
let stride_input = scope.create_local(u32::as_elem());
let stride_output = scope.create_local(u32::as_elem());
let shape = scope.create_local(u32::as_elem());
for i in 0..self.rank {
cpa!(scope, stride_input = stride(input, i));

View File

@ -2,15 +2,15 @@ use crate::{
element::JitElement,
kernel::{self},
tensor::JitTensor,
JitRuntime,
IntElement, JitRuntime,
};
use cubecl::prelude::*;
use cubecl::{calculate_cube_count_elemwise, CubeDim};
#[cube(launch_unchecked)]
fn scatter_kernel<T: Numeric>(
fn scatter_kernel<T: Numeric, I: Int>(
input: &mut Tensor<T>,
indices: &Tensor<i32>,
indices: &Tensor<I>,
value: &Tensor<T>,
dim: &u32,
) {
@ -65,7 +65,7 @@ fn scatter_kernel<T: Numeric>(
}
}
pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement>(
pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: IntElement>(
dim: usize,
tensor: JitTensor<R, E>,
indices: JitTensor<R, I>,
@ -105,7 +105,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement>(
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
unsafe {
scatter_kernel::launch_unchecked::<E, R>(
scatter_kernel::launch_unchecked::<E, I, R>(
&indices.client.clone(),
cube_count,
cube_dim,

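With the kernel now generic over the index element (`I: Int` device-side, `I: IntElement` host-side), scatter can be launched directly on `i64` indices instead of assuming `i32`. The snippet below is a plain-Rust sketch of the same idea with a generic integer index type; the additive semantics, flat buffer, and names are illustrative only, not the launched kernel.

```rust
/// Scatter-style accumulation into a flat buffer with a generic integer
/// index type, illustrating why the index element can stay generic
/// instead of being forced to i32 on the host.
fn scatter_add<T, I>(input: &mut [T], indices: &[I], values: &[T])
where
    T: Copy + std::ops::AddAssign,
    I: Copy + TryInto<usize>,
    <I as TryInto<usize>>::Error: std::fmt::Debug,
{
    for (idx, &v) in indices.iter().zip(values) {
        let i: usize = (*idx).try_into().expect("index must be non-negative");
        input[i] += v;
    }
}

fn main() {
    let mut data = [0.0f32; 4];
    // The same routine accepts i32 or i64 indices without an extra cast pass.
    scatter_add(&mut data, &[0i64, 2, 2], &[1.0, 2.0, 3.0]);
    assert_eq!(data, [1.0, 0.0, 5.0, 0.0]);
}
```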
View File

@ -4,7 +4,8 @@ use crate::{
use burn_tensor::{ElementConversion, Shape};
use cubecl::{
cpa,
ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
ir::{Builtin, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -29,13 +30,13 @@ impl SliceComputeShader {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.create_local(Elem::UInt);
let offset_input = scope.zero(u32::as_elem());
let offset_local = scope.create_local(u32::as_elem());
let stride_input = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_output = scope.create_local(Elem::UInt);
let range_start = scope.create_local(Elem::UInt);
let stride_input = scope.create_local(u32::as_elem());
let stride_output = scope.create_local(u32::as_elem());
let shape_output = scope.create_local(u32::as_elem());
let range_start = scope.create_local(u32::as_elem());
for i in 0..self.rank {
cpa!(scope, stride_input = stride(input, i));
@ -45,7 +46,7 @@ impl SliceComputeShader {
scope,
range_start = cast(Variable::new(
VariableKind::GlobalScalar(i as u16),
Item::new(Elem::UInt)
Item::new(u32::as_elem())
))
);
@ -85,7 +86,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
visibility: Visibility::Read,
};
let ranges = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: self.rank,
};
let output = OutputInfo::Array { item };

View File

@ -2,7 +2,8 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
use burn_tensor::ElementConversion;
use cubecl::{
cpa,
ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
ir::{Builtin, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
};
use std::{marker::PhantomData, ops::Range};
@ -26,18 +27,18 @@ impl SliceAssignComputeShader {
let value = self.value;
let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt);
let offset_value = scope.zero(Elem::UInt);
let offset_input = scope.zero(u32::as_elem());
let offset_value = scope.zero(u32::as_elem());
let offset_local = scope.create_local(Elem::UInt);
let offset_local_value = scope.create_local(Elem::UInt);
let offset_local_input = scope.create_local(Elem::UInt);
let offset_local = scope.create_local(u32::as_elem());
let offset_local_value = scope.create_local(u32::as_elem());
let offset_local_input = scope.create_local(u32::as_elem());
let stride_input = scope.create_local(Elem::UInt);
let stride_value = scope.create_local(Elem::UInt);
let shape_value = scope.create_local(Elem::UInt);
let shape_input = scope.create_local(Elem::UInt);
let range_start = scope.create_local(Elem::UInt);
let stride_input = scope.create_local(u32::as_elem());
let stride_value = scope.create_local(u32::as_elem());
let shape_value = scope.create_local(u32::as_elem());
let shape_input = scope.create_local(u32::as_elem());
let range_start = scope.create_local(u32::as_elem());
for i in 0..self.rank {
cpa!(scope, stride_input = stride(input, i));
@ -48,7 +49,7 @@ impl SliceAssignComputeShader {
scope,
range_start = cast(Variable::new(
VariableKind::GlobalScalar(i as u16),
Item::new(Elem::UInt)
Item::new(u32::as_elem())
))
);
@ -98,7 +99,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
visibility: Visibility::Read,
};
let ranges = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: self.rank,
};

View File

@ -1,6 +1,7 @@
use cubecl::{
cpa,
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -27,23 +28,23 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
let id = Variable::builtin(Builtin::AbsolutePos);
let elem = E::cube_elem();
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_stride_0 = scope.create_local(u32::as_elem());
let input_stride_1 = scope.create_local(u32::as_elem());
let input_stride_2 = scope.create_local(u32::as_elem());
let input_stride_3 = scope.create_local(u32::as_elem());
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(u32::as_elem());
let input_shape_3 = scope.create_local(u32::as_elem());
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, input_stride_0 = stride(input, 0u32));
cpa!(scope, input_stride_1 = stride(input, 1u32));
@ -63,10 +64,10 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let h = scope.create_local(Elem::UInt);
let w = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let c = scope.create_local(u32::as_elem());
let h = scope.create_local(u32::as_elem());
let w = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -80,23 +81,23 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
cpa!(scope, w = id / output_stride_3);
cpa!(scope, w = w % output_shape_3);
let input_height = scope.create_local(Elem::UInt);
let output_height = scope.create_local(Elem::UInt);
let input_height = scope.create_local(u32::as_elem());
let output_height = scope.create_local(u32::as_elem());
let output_height_float = scope.create_local(elem);
let input_width = scope.create_local(Elem::UInt);
let output_width = scope.create_local(Elem::UInt);
let input_width = scope.create_local(u32::as_elem());
let output_width = scope.create_local(u32::as_elem());
let output_width_float = scope.create_local(elem);
let frac = scope.create_local(elem);
let numerator = scope.create_local(Elem::UInt);
let numerator = scope.create_local(u32::as_elem());
let numerator_float = scope.create_local(elem);
let not_zero = scope.create_local(Elem::Bool);
let y_in_float = scope.create_local(elem);
let y_in = scope.create_local(Elem::UInt);
let y_in = scope.create_local(u32::as_elem());
let yw = scope.create_local(elem);
let y_tmp = scope.create_local(Elem::UInt);
let y_tmp = scope.create_local(u32::as_elem());
cpa!(scope, input_height = input_shape_2 - 1u32);
cpa!(scope, output_height = output_shape_2 - 1u32);
@ -109,7 +110,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
cpa!(scope, y_in = cast(y_in_float));
cpa!(scope, yw = frac - y_in_float);
let y0 = scope.zero(Elem::UInt);
let y0 = scope.zero(u32::as_elem());
cpa!(scope, not_zero = y_in != 0u32);
cpa!(scope, if(not_zero).then(|scope|{
cpa!(scope, y0 = y_in - 1u32);
@ -124,9 +125,9 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
let y3 = Self::min(scope, y_tmp, input_height);
let x_in_float = scope.create_local(elem);
let x_in = scope.create_local(Elem::UInt);
let x_in = scope.create_local(u32::as_elem());
let xw = scope.create_local(elem);
let x_tmp = scope.create_local(Elem::UInt);
let x_tmp = scope.create_local(u32::as_elem());
cpa!(scope, input_width = input_shape_3 - 1u32);
cpa!(scope, output_width = output_shape_3 - 1u32);
@ -139,7 +140,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
cpa!(scope, x_in = cast(x_in_float));
cpa!(scope, xw = frac - x_in_float);
let x0 = scope.zero(Elem::UInt);
let x0 = scope.zero(u32::as_elem());
cpa!(scope, not_zero = x_in != 0u32);
cpa!(scope, if(not_zero).then(|scope|{
cpa!(scope, x0 = x_in - 1u32);
@ -154,20 +155,20 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
cpa!(scope, x_tmp = x_in + 2u32);
let x3 = Self::min(scope, x_tmp, input_width);
let index_base = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let index_base = scope.create_local(u32::as_elem());
let index_tmp = scope.create_local(u32::as_elem());
cpa!(scope, index_base = b * input_stride_0);
cpa!(scope, index_tmp = c * input_stride_1);
cpa!(scope, index_base += index_tmp);
let y0_stride = scope.create_local(Elem::UInt);
let y1_stride = scope.create_local(Elem::UInt);
let y2_stride = scope.create_local(Elem::UInt);
let y3_stride = scope.create_local(Elem::UInt);
let x0_stride = scope.create_local(Elem::UInt);
let x1_stride = scope.create_local(Elem::UInt);
let x2_stride = scope.create_local(Elem::UInt);
let x3_stride = scope.create_local(Elem::UInt);
let y0_stride = scope.create_local(u32::as_elem());
let y1_stride = scope.create_local(u32::as_elem());
let y2_stride = scope.create_local(u32::as_elem());
let y3_stride = scope.create_local(u32::as_elem());
let x0_stride = scope.create_local(u32::as_elem());
let x1_stride = scope.create_local(u32::as_elem());
let x2_stride = scope.create_local(u32::as_elem());
let x3_stride = scope.create_local(u32::as_elem());
cpa!(scope, y0_stride = y0 * input_stride_2);
cpa!(scope, y1_stride = y1 * input_stride_2);
cpa!(scope, y2_stride = y2 * input_stride_2);
@ -177,10 +178,10 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
cpa!(scope, x2_stride = x2 * input_stride_3);
cpa!(scope, x3_stride = x3 * input_stride_3);
let index_0 = scope.create_local(Elem::UInt);
let index_1 = scope.create_local(Elem::UInt);
let index_2 = scope.create_local(Elem::UInt);
let index_3 = scope.create_local(Elem::UInt);
let index_0 = scope.create_local(u32::as_elem());
let index_1 = scope.create_local(u32::as_elem());
let index_2 = scope.create_local(u32::as_elem());
let index_3 = scope.create_local(u32::as_elem());
let inp_0 = scope.create_local(input.item);
let inp_1 = scope.create_local(input.item);
let inp_2 = scope.create_local(input.item);

View File

@ -1,6 +1,7 @@
use cubecl::{
cpa,
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -25,23 +26,23 @@ impl InterpolateBilinearShader {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_stride_0 = scope.create_local(u32::as_elem());
let input_stride_1 = scope.create_local(u32::as_elem());
let input_stride_2 = scope.create_local(u32::as_elem());
let input_stride_3 = scope.create_local(u32::as_elem());
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(u32::as_elem());
let input_shape_3 = scope.create_local(u32::as_elem());
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, input_stride_0 = stride(input, 0u32));
cpa!(scope, input_stride_1 = stride(input, 1u32));
@ -61,10 +62,10 @@ impl InterpolateBilinearShader {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let h = scope.create_local(Elem::UInt);
let w = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let c = scope.create_local(u32::as_elem());
let h = scope.create_local(u32::as_elem());
let w = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -80,22 +81,22 @@ impl InterpolateBilinearShader {
let factor_float = scope.create_local(input.item);
let numerator_float = scope.create_local(input.item);
let numerator_int = scope.create_local(Elem::UInt);
let numerator_int = scope.create_local(u32::as_elem());
let denominator_float = scope.create_local(input.item);
let denominator_int = scope.create_local(Elem::UInt);
let denominator_int = scope.create_local(u32::as_elem());
let frac = scope.create_local(input.item);
let v0 = scope.create_local(input.item);
let v1 = scope.create_local(input.item);
let one = scope.create_with_value(1f32, input.item);
let y0 = scope.create_local(Elem::UInt);
let y1 = scope.create_local(Elem::UInt);
let y0 = scope.create_local(u32::as_elem());
let y1 = scope.create_local(u32::as_elem());
let yw = scope.create_local(input.item);
let yw_ = scope.create_local(input.item);
let x0 = scope.create_local(Elem::UInt);
let x1 = scope.create_local(Elem::UInt);
let x0 = scope.create_local(u32::as_elem());
let x1 = scope.create_local(u32::as_elem());
let xw = scope.create_local(input.item);
let xw_ = scope.create_local(input.item);
@ -129,13 +130,13 @@ impl InterpolateBilinearShader {
cpa!(scope, x0 = cast(v0));
cpa!(scope, x1 = cast(v1));
let index_base = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let index = scope.create_local(Elem::UInt);
let y0_stride = scope.create_local(Elem::UInt);
let y1_stride = scope.create_local(Elem::UInt);
let x0_stride = scope.create_local(Elem::UInt);
let x1_stride = scope.create_local(Elem::UInt);
let index_base = scope.create_local(u32::as_elem());
let index_tmp = scope.create_local(u32::as_elem());
let index = scope.create_local(u32::as_elem());
let y0_stride = scope.create_local(u32::as_elem());
let y1_stride = scope.create_local(u32::as_elem());
let x0_stride = scope.create_local(u32::as_elem());
let x1_stride = scope.create_local(u32::as_elem());
let p_a = scope.create_local(input.item);
let p_b = scope.create_local(input.item);
let p_c = scope.create_local(input.item);

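For context, this shader computes bilinear interpolation: the integer neighbours `y0`/`y1` and `x0`/`x1`, the fractional weights `yw`/`xw`, and a blend of the four corner samples `p_a`..`p_d`. A scalar sketch of that final blend as a hypothetical free function (not the cpa-generated code):

```rust
/// Bilinear blend of the four corner samples around a fractional
/// position; `xw`/`yw` are the fractional distances from the x0/y0 corner.
fn bilinear(p_a: f32, p_b: f32, p_c: f32, p_d: f32, xw: f32, yw: f32) -> f32 {
    // Assumed corner layout: p_a = (y0, x0), p_b = (y0, x1),
    //                        p_c = (y1, x0), p_d = (y1, x1).
    let top = p_a * (1.0 - xw) + p_b * xw;
    let bottom = p_c * (1.0 - xw) + p_d * xw;
    top * (1.0 - yw) + bottom * yw
}

fn main() {
    // Halfway between all four corners reduces to a plain average.
    assert_eq!(bilinear(0.0, 1.0, 2.0, 3.0, 0.5, 0.5), 1.5);
}
```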
View File

@ -1,6 +1,7 @@
use cubecl::{
cpa,
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -27,23 +28,23 @@ impl<E: JitElement> InterpolateNearestShader<E> {
let id = Variable::builtin(Builtin::AbsolutePos);
let elem = E::cube_elem();
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
let input_stride_2 = scope.create_local(Elem::UInt);
let input_stride_3 = scope.create_local(Elem::UInt);
let input_stride_0 = scope.create_local(u32::as_elem());
let input_stride_1 = scope.create_local(u32::as_elem());
let input_stride_2 = scope.create_local(u32::as_elem());
let input_stride_3 = scope.create_local(u32::as_elem());
let input_shape_2 = scope.create_local(Elem::UInt);
let input_shape_3 = scope.create_local(Elem::UInt);
let input_shape_2 = scope.create_local(u32::as_elem());
let input_shape_3 = scope.create_local(u32::as_elem());
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, input_stride_0 = stride(input, 0u32));
cpa!(scope, input_stride_1 = stride(input, 1u32));
@ -63,10 +64,10 @@ impl<E: JitElement> InterpolateNearestShader<E> {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let h = scope.create_local(Elem::UInt);
let w = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let c = scope.create_local(u32::as_elem());
let h = scope.create_local(u32::as_elem());
let w = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -85,8 +86,8 @@ impl<E: JitElement> InterpolateNearestShader<E> {
let denominator_float = scope.create_local(elem);
let x = scope.create_local(elem);
let y = scope.create_local(elem);
let xu = scope.create_local(Elem::UInt);
let yu = scope.create_local(Elem::UInt);
let xu = scope.create_local(u32::as_elem());
let yu = scope.create_local(u32::as_elem());
cpa!(scope, factor_float = cast(h));
cpa!(scope, numerator_float = cast(input_shape_2));
@ -104,8 +105,8 @@ impl<E: JitElement> InterpolateNearestShader<E> {
cpa!(scope, x = floor(x));
cpa!(scope, xu = cast(x));
let index = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let index = scope.create_local(u32::as_elem());
let index_tmp = scope.create_local(u32::as_elem());
let val = scope.create_local(output.item);
cpa!(scope, index = b * input_stride_0);

View File

@ -1,6 +1,7 @@
use cubecl::{
cpa,
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -26,25 +27,25 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt);
let grad_stride_2 = scope.create_local(Elem::UInt);
let grad_stride_3 = scope.create_local(Elem::UInt);
let grad_stride_0 = scope.create_local(u32::as_elem());
let grad_stride_1 = scope.create_local(u32::as_elem());
let grad_stride_2 = scope.create_local(u32::as_elem());
let grad_stride_3 = scope.create_local(u32::as_elem());
let grad_shape_0 = scope.create_local(Elem::UInt);
let grad_shape_1 = scope.create_local(Elem::UInt);
let grad_shape_2 = scope.create_local(Elem::UInt);
let grad_shape_3 = scope.create_local(Elem::UInt);
let grad_shape_0 = scope.create_local(u32::as_elem());
let grad_shape_1 = scope.create_local(u32::as_elem());
let grad_shape_2 = scope.create_local(u32::as_elem());
let grad_shape_3 = scope.create_local(u32::as_elem());
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, grad_stride_0 = stride(grad, 0u32));
cpa!(scope, grad_stride_1 = stride(grad, 1u32));
@ -66,10 +67,10 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let oh = scope.create_local(Elem::UInt);
let ow = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let c = scope.create_local(u32::as_elem());
let oh = scope.create_local(u32::as_elem());
let ow = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -90,11 +91,11 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
let result = scope.create_local(grad.item);
let index_grad = scope.create_local(Elem::UInt);
let index_grad_0 = scope.create_local(Elem::UInt);
let index_grad_1 = scope.create_local(Elem::UInt);
let index_grad_2 = scope.create_local(Elem::UInt);
let index_grad_3 = scope.create_local(Elem::UInt);
let index_grad = scope.create_local(u32::as_elem());
let index_grad_0 = scope.create_local(u32::as_elem());
let index_grad_1 = scope.create_local(u32::as_elem());
let index_grad_2 = scope.create_local(u32::as_elem());
let index_grad_3 = scope.create_local(u32::as_elem());
cpa!(scope, index_grad_0 = b * grad_stride_0);
cpa!(scope, index_grad_1 = c * grad_stride_1);
@ -135,7 +136,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
let elem = E::cube_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
let index = scope.create_local(u32::as_elem());
cpa!(scope, index = input_index * output_size);
cpa!(scope, numerator_float = cast(index));
@ -156,9 +157,9 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
let elem = E::cube_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
let index = scope.create_local(u32::as_elem());
let min = scope.create_local(Elem::Bool);
let end_index = scope.create_local(Elem::UInt);
let end_index = scope.create_local(u32::as_elem());
cpa!(scope, index = input_index + 1u32);
cpa!(scope, index *= output_size);

View File

@ -23,7 +23,7 @@ pub struct MatmulAutotuneOperationSet<R: JitRuntime, E: FloatElement> {
impl<R: JitRuntime, E: FloatElement> MatmulAutotuneOperationSet<R, E> {
fn new(lhs: JitTensor<R, E>, rhs: JitTensor<R, E>, out: JitTensor<R, E>) -> Self {
Self {
key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)),
key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())),
lhs,
rhs,
out,

View File

@ -1,5 +1,5 @@
use crate::tune::anchor;
use burn_tensor::Shape;
use burn_tensor::{DType, Shape};
use core::fmt::Debug;
use serde::{Deserialize, Serialize};
use std::{cmp::max, fmt::Display, hash::Hash};
@ -13,19 +13,21 @@ pub struct MatmulAutotuneKey {
anchored_k: usize,
anchored_n: usize,
anchored_batch: usize,
dtype: DType,
}
impl Display for MatmulAutotuneKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
format!(
"Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}",
"Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?} dtype:{:?}",
self.round,
self.broadcast,
self.anchored_m,
self.anchored_k,
self.anchored_n,
self.anchored_batch
self.anchored_batch,
self.dtype
)
.as_str(),
)
@ -34,7 +36,7 @@ impl Display for MatmulAutotuneKey {
impl MatmulAutotuneKey {
/// Create a matmul autotune key from the input shapes
pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) -> Self {
pub fn new(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
let ndims = lhs_shape.num_dims();
let m = lhs_shape.dims[ndims - 2];
let k = lhs_shape.dims[ndims - 1];
@ -62,6 +64,7 @@ impl MatmulAutotuneKey {
anchored_k: anchor(k, None),
anchored_n: anchor(n, None),
anchored_batch: anchor(batch_product, Some(256)),
dtype,
}
}
}
@ -74,7 +77,7 @@ mod tests {
fn matmul_autotune_key_all_same_and_round() {
let lhs_shape: Shape = [4, 512, 512].into();
let rhs_shape: Shape = [4, 512, 512].into();
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
assert!(key.round);
assert!(!key.broadcast);
@ -87,7 +90,7 @@ mod tests {
fn matmul_autotune_key_all_different() {
let lhs_shape: Shape = [2, 3, 511, 512].into();
let rhs_shape: Shape = [3, 2, 512, 513].into();
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
assert!(!key.round);
assert!(key.broadcast);
@ -101,7 +104,7 @@ mod tests {
fn matmul_autotune_key_large_batch() {
let lhs_shape: Shape = [128, 512, 511, 512].into();
let rhs_shape: Shape = [200, 400, 512, 513].into();
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
assert!(key.anchored_batch == 256);
}

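Adding `DType` to `MatmulAutotuneKey` makes the autotune cache dtype-aware: the same shapes tuned at `f16` and `f32` now resolve to separate entries, so a configuration picked for one precision is never reused for another. A hypothetical sketch of the idea with stand-in types, not the real autotune API:

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the real autotune key and kernel labels.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum DType { F32, F16 }

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct MatmulKey { m: usize, k: usize, n: usize, dtype: DType }

fn main() {
    let mut best_kernel: HashMap<MatmulKey, &'static str> = HashMap::new();
    best_kernel.insert(MatmulKey { m: 512, k: 512, n: 512, dtype: DType::F32 }, "tiling_2d");
    best_kernel.insert(MatmulKey { m: 512, k: 512, n: 512, dtype: DType::F16 }, "cmma");

    // Same shape, different dtype -> distinct cache entries.
    assert_ne!(
        best_kernel[&MatmulKey { m: 512, k: 512, n: 512, dtype: DType::F32 }],
        best_kernel[&MatmulKey { m: 512, k: 512, n: 512, dtype: DType::F16 }],
    );
}
```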
View File

@ -10,6 +10,7 @@ use cubecl::{
ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -36,23 +37,23 @@ impl AvgPool2dBackwardComputeShader {
let output = self.output;
let id = Variable::builtin(Builtin::AbsolutePos);
let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt);
let grad_stride_2 = scope.create_local(Elem::UInt);
let grad_stride_3 = scope.create_local(Elem::UInt);
let grad_stride_0 = scope.create_local(u32::as_elem());
let grad_stride_1 = scope.create_local(u32::as_elem());
let grad_stride_2 = scope.create_local(u32::as_elem());
let grad_stride_3 = scope.create_local(u32::as_elem());
let grad_shape_2 = scope.create_local(Elem::UInt);
let grad_shape_3 = scope.create_local(Elem::UInt);
let grad_shape_2 = scope.create_local(u32::as_elem());
let grad_shape_3 = scope.create_local(u32::as_elem());
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, grad_stride_0 = stride(grad, 0u32));
cpa!(scope, grad_stride_1 = stride(grad, 1u32));
@ -72,16 +73,16 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let ih = scope.create_local(Elem::UInt);
let iw = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let c = scope.create_local(u32::as_elem());
let ih = scope.create_local(u32::as_elem());
let iw = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -95,16 +96,16 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, iw = id / output_stride_3);
cpa!(scope, iw = iw % output_shape_3);
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
let index_current = scope.create_local(u32::as_elem());
let index_current_tmp = scope.create_local(u32::as_elem());
cpa!(scope, index_current = ih * output_stride_2);
cpa!(scope, index_current_tmp = iw * output_stride_3);
cpa!(scope, index_current += index_current_tmp);
let index = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let index_base = scope.create_local(Elem::UInt);
let index = scope.create_local(u32::as_elem());
let index_tmp = scope.create_local(u32::as_elem());
let index_base = scope.create_local(u32::as_elem());
let grad_accumulation = scope.zero(grad.item);
let result = scope.create_local(grad.item);
@ -130,14 +131,14 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, index_tmp = c * grad_stride_1);
cpa!(scope, index_base += index_tmp);
let border_bottom = scope.create_local(Elem::UInt);
let border_right = scope.create_local(Elem::UInt);
let begin_h = scope.create_local(Elem::UInt);
let begin_w = scope.create_local(Elem::UInt);
let iw_start = scope.create_local(Elem::UInt);
let iw_end = scope.create_local(Elem::UInt);
let ih_start = scope.create_local(Elem::UInt);
let ih_end = scope.create_local(Elem::UInt);
let border_bottom = scope.create_local(u32::as_elem());
let border_right = scope.create_local(u32::as_elem());
let begin_h = scope.create_local(u32::as_elem());
let begin_w = scope.create_local(u32::as_elem());
let iw_start = scope.create_local(u32::as_elem());
let iw_end = scope.create_local(u32::as_elem());
let ih_start = scope.create_local(u32::as_elem());
let ih_end = scope.create_local(u32::as_elem());
let after_start = scope.create_local(Elem::Bool);
let before_end = scope.create_local(Elem::Bool);
let contributed_h = scope.create_local(Elem::Bool);
@ -147,9 +148,9 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, begin_h = ih + padding_0);
cpa!(scope, begin_w = iw + padding_1);
let ih_diff = scope.create_local(Elem::UInt);
let iw_diff = scope.create_local(Elem::UInt);
let count_int = scope.create_local(Elem::UInt);
let ih_diff = scope.create_local(u32::as_elem());
let iw_diff = scope.create_local(u32::as_elem());
let count_int = scope.create_local(u32::as_elem());
cpa!(
scope,
@ -216,12 +217,12 @@ impl AvgPool2dBackwardComputeShader {
output_stride_2: Variable,
output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
let [kernel_size_0, kernel_size_1] = self.kernel_size;
@ -273,8 +274,8 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
cpa!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
let oh_start = scope.create_local(Elem::UInt);
let ow_start = scope.create_local(Elem::UInt);
let oh_start = scope.create_local(u32::as_elem());
let ow_start = scope.create_local(u32::as_elem());
cpa!(scope, oh_start = cast(oh_start_tmp));
cpa!(scope, ow_start = cast(ow_start_tmp));
@ -285,11 +286,11 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, oh_end_tmp = max(kms_0, 0i32));
cpa!(scope, ow_end_tmp = max(kms_1, 0i32));
let oh_end = scope.create_local(Elem::UInt);
let ow_end = scope.create_local(Elem::UInt);
let oh_end = scope.create_local(u32::as_elem());
let ow_end = scope.create_local(u32::as_elem());
let oh_end_limit = scope.create_local(Elem::UInt);
let ow_end_limit = scope.create_local(Elem::UInt);
let oh_end_limit = scope.create_local(u32::as_elem());
let ow_end_limit = scope.create_local(u32::as_elem());
cpa!(scope, oh_end = cast(oh_end_tmp));
cpa!(scope, ow_end = cast(ow_end_tmp));
@ -303,8 +304,8 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, oh_end = min(oh_end, oh_end_limit));
cpa!(scope, ow_end = min(ow_end, ow_end_limit));
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
let index_current = scope.create_local(u32::as_elem());
let index_current_tmp = scope.create_local(u32::as_elem());
cpa!(scope, index_current = ih * output_stride_2);
cpa!(scope, index_current_tmp = iw * output_stride_3);
@ -340,7 +341,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: 6,
};
let output = OutputInfo::Array { item };

View File

@ -10,6 +10,7 @@ use cubecl::{
ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -36,23 +37,23 @@ impl MaxPool2dBackwardComputeShader {
let indices = self.indices;
let id = Variable::builtin(Builtin::AbsolutePos);
let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt);
let grad_stride_2 = scope.create_local(Elem::UInt);
let grad_stride_3 = scope.create_local(Elem::UInt);
let grad_stride_0 = scope.create_local(u32::as_elem());
let grad_stride_1 = scope.create_local(u32::as_elem());
let grad_stride_2 = scope.create_local(u32::as_elem());
let grad_stride_3 = scope.create_local(u32::as_elem());
let grad_shape_2 = scope.create_local(Elem::UInt);
let grad_shape_3 = scope.create_local(Elem::UInt);
let grad_shape_2 = scope.create_local(u32::as_elem());
let grad_shape_3 = scope.create_local(u32::as_elem());
let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);
let output_stride_0 = scope.create_local(u32::as_elem());
let output_stride_1 = scope.create_local(u32::as_elem());
let output_stride_2 = scope.create_local(u32::as_elem());
let output_stride_3 = scope.create_local(u32::as_elem());
let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);
let output_shape_0 = scope.create_local(u32::as_elem());
let output_shape_1 = scope.create_local(u32::as_elem());
let output_shape_2 = scope.create_local(u32::as_elem());
let output_shape_3 = scope.create_local(u32::as_elem());
cpa!(scope, grad_stride_0 = stride(grad, 0u32));
cpa!(scope, grad_stride_1 = stride(grad, 1u32));
@ -72,10 +73,10 @@ impl MaxPool2dBackwardComputeShader {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let ih = scope.create_local(Elem::UInt);
let iw = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
let c = scope.create_local(u32::as_elem());
let ih = scope.create_local(u32::as_elem());
let iw = scope.create_local(u32::as_elem());
cpa!(scope, b = id / output_stride_0);
cpa!(scope, b = b % output_shape_0);
@ -89,8 +90,8 @@ impl MaxPool2dBackwardComputeShader {
cpa!(scope, iw = id / output_stride_3);
cpa!(scope, iw = iw % output_shape_3);
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
let index_current = scope.create_local(u32::as_elem());
let index_current_tmp = scope.create_local(u32::as_elem());
cpa!(scope, index_current = ih * output_stride_2);
cpa!(scope, index_current_tmp = iw * output_stride_3);
@ -98,12 +99,12 @@ impl MaxPool2dBackwardComputeShader {
let index_select = scope.create_local(Elem::Int(IntKind::I32));
let index_max = scope.create_local(Elem::UInt);
let index_max = scope.create_local(u32::as_elem());
let is_max = scope.create_local(Elem::Bool);
let index = scope.create_local(Elem::UInt);
let index_base = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let index = scope.create_local(u32::as_elem());
let index_base = scope.create_local(u32::as_elem());
let index_tmp = scope.create_local(u32::as_elem());
let grad_accumulation = scope.zero(grad.item);
let result = scope.create_local(grad.item);
@ -162,12 +163,12 @@ impl MaxPool2dBackwardComputeShader {
output_stride_2: Variable,
output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
let [kernel_size_0, kernel_size_1] = self.kernel_size;
@ -219,8 +220,8 @@ impl MaxPool2dBackwardComputeShader {
cpa!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
cpa!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
let oh_start = scope.create_local(Elem::UInt);
let ow_start = scope.create_local(Elem::UInt);
let oh_start = scope.create_local(u32::as_elem());
let ow_start = scope.create_local(u32::as_elem());
cpa!(scope, oh_start = cast(oh_start_tmp));
cpa!(scope, ow_start = cast(ow_start_tmp));
@ -231,11 +232,11 @@ impl MaxPool2dBackwardComputeShader {
cpa!(scope, oh_end_tmp = max(kms_0, 0i32));
cpa!(scope, ow_end_tmp = max(kms_1, 0i32));
let oh_end = scope.create_local(Elem::UInt);
let ow_end = scope.create_local(Elem::UInt);
let oh_end = scope.create_local(u32::as_elem());
let ow_end = scope.create_local(u32::as_elem());
let oh_end_limit = scope.create_local(Elem::UInt);
let ow_end_limit = scope.create_local(Elem::UInt);
let oh_end_limit = scope.create_local(u32::as_elem());
let ow_end_limit = scope.create_local(u32::as_elem());
cpa!(scope, oh_end = cast(oh_end_tmp));
cpa!(scope, ow_end = cast(ow_end_tmp));
@ -249,8 +250,8 @@ impl MaxPool2dBackwardComputeShader {
cpa!(scope, oh_end = min(oh_end, oh_end_limit));
cpa!(scope, ow_end = min(ow_end, ow_end_limit));
let index_current = scope.create_local(Elem::UInt);
let index_current_tmp = scope.create_local(Elem::UInt);
let index_current = scope.create_local(u32::as_elem());
let index_current_tmp = scope.create_local(u32::as_elem());
cpa!(scope, index_current = ih * output_stride_2);
cpa!(scope, index_current_tmp = iw * output_stride_3);
@ -295,7 +296,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: 6,
};
let output = OutputInfo::Array { item };

View File

@ -1,6 +1,6 @@
use cubecl::{
cpa,
ir::{Builtin, Elem, Item, Scope, Variable, VariableKind},
ir::{Builtin, Item, Scope, Variable, VariableKind},
prelude::*,
CubeCountSettings, Execution, InputInfo, OutputInfo,
};
@ -60,10 +60,10 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
let seeds = [seed0, seed1, seed2, seed3];
let mut args = Vec::<Variable>::new();
@ -80,7 +80,7 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
size: P::args_length(),
};
let seeds = InputInfo::Scalar {
elem: Elem::UInt,
elem: u32::as_elem(),
size: 4,
};
let out = OutputInfo::Array { item };
@ -166,40 +166,40 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
let cube_count_y = Variable::builtin(Builtin::CubeCountY);
let local_index = Variable::builtin(Builtin::UnitPos);
let n_invocations = scope.create_local(Elem::UInt);
let n_invocations = scope.create_local(u32::as_elem());
cpa!(scope, n_invocations = cube_dim_x);
cpa!(scope, n_invocations *= cube_dim_y);
let cube_offset = scope.create_local(Elem::UInt);
let cube_offset = scope.create_local(u32::as_elem());
cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
cpa!(scope, cube_offset += cube_pos_y);
cpa!(scope, cube_offset *= n_invocations);
let write_index_base = scope.create_local(Elem::UInt);
let write_index_base = scope.create_local(u32::as_elem());
cpa!(scope, write_index_base = cube_offset);
cpa!(scope, write_index_base *= n_values_per_thread);
cpa!(scope, write_index_base += local_index);
// Set state with unique seeds
let thread_seed = scope.create_local(Elem::UInt);
let thread_seed = scope.create_local(u32::as_elem());
cpa!(scope, thread_seed = cast(1000000007));
let thread_seed_index = scope.create_local(Elem::UInt);
let thread_seed_index = scope.create_local(u32::as_elem());
cpa!(scope, thread_seed_index = cube_offset + local_index);
cpa!(scope, thread_seed *= thread_seed_index);
let state_0 = scope.create_local(Elem::UInt);
let state_0 = scope.create_local(u32::as_elem());
cpa!(scope, state_0 = thread_seed);
cpa!(scope, state_0 += seed_0);
let state_1 = scope.create_local(Elem::UInt);
let state_1 = scope.create_local(u32::as_elem());
cpa!(scope, state_1 = thread_seed);
cpa!(scope, state_1 += seed_1);
let state_2 = scope.create_local(Elem::UInt);
let state_2 = scope.create_local(u32::as_elem());
cpa!(scope, state_2 = thread_seed);
cpa!(scope, state_2 += seed_2);
let state_3 = scope.create_local(Elem::UInt);
let state_3 = scope.create_local(u32::as_elem());
cpa!(scope, state_3 = thread_seed);
cpa!(scope, state_3 += seed_3);
@ -260,7 +260,7 @@ fn taus_step(
s3: Variable,
m: Variable,
) {
let b = scope.create_local(Elem::UInt);
let b = scope.create_local(u32::as_elem());
cpa!(scope, b = z << s1);
cpa!(scope, b = b ^ z);
cpa!(scope, b = b >> s2);

View File

@ -2,6 +2,7 @@ use burn_tensor::Shape;
use cubecl::{
cpa,
ir::{Elem, FloatKind, Scope, Variable},
prelude::*,
};
use crate::{
@ -34,7 +35,14 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
output: Variable,
) {
let float_elem = Elem::Float(FloatKind::F32);
let prob = args[0];
let mut prob = args[0];
if prob.item.elem() != Elem::Float(FloatKind::F32) {
let prob_f32 = scope.create_local(float_elem);
cpa!(scope, prob_f32 = cast(prob));
prob = prob_f32;
}
cpa!(
scope,
range(0u32, n_values_per_thread).for_each(|i, scope| {
@ -43,7 +51,7 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
taus_step_2(scope, state_2);
lcg_step(scope, state_3);
let int_random = scope.create_local(Elem::UInt);
let int_random = scope.create_local(u32::as_elem());
cpa!(scope, int_random = state_0 ^ state_1);
cpa!(scope, int_random = int_random ^ state_2);
cpa!(scope, int_random = int_random ^ state_3);
@ -54,7 +62,7 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
let bernoulli = scope.create_local(Elem::Bool);
cpa!(scope, bernoulli = float_random < prob);
let write_index = scope.create_local(Elem::UInt);
let write_index = scope.create_local(u32::as_elem());
cpa!(scope, write_index = i * n_invocations);
cpa!(scope, write_index += write_index_base);
cpa!(scope, output[write_index] = bernoulli);

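These PRNG kernels advance three Tausworthe states and one LCG state, XOR the four into a `u32`, convert it to a float in [0, 1), and (for Bernoulli) compare against the probability, which is now always promoted to `f32` before the comparison. The sketch below follows the common hybrid Tausworthe/LCG formulation; the shift and mask constants are an assumption, not copied from the kernels.

```rust
/// One Tausworthe step: b = ((z << s1) ^ z) >> s2; z' = ((z & m) << s3) ^ b.
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
    let b = ((z << s1) ^ z) >> s2;
    ((z & m) << s3) ^ b
}

/// Linear congruential step used for the fourth state.
fn lcg_step(z: u32) -> u32 {
    z.wrapping_mul(1664525).wrapping_add(1013904223)
}

/// Draw one Bernoulli sample: XOR the four states into a u32,
/// map it to [0, 1) as f32, and threshold against the probability.
fn bernoulli(states: &mut [u32; 4], prob: f32) -> bool {
    states[0] = taus_step(states[0], 13, 19, 12, 4294967294); // assumed constants
    states[1] = taus_step(states[1], 2, 25, 4, 4294967288);
    states[2] = taus_step(states[2], 3, 11, 17, 4294967280);
    states[3] = lcg_step(states[3]);
    let int_random = states[0] ^ states[1] ^ states[2] ^ states[3];
    let float_random = int_random as f32 / u32::MAX as f32;
    float_random < prob
}

fn main() {
    let mut states = [0x1234_5678, 0x9abc_def0, 0x0f0f_0f0f, 0x7777_7777];
    let draws: Vec<bool> = (0..8).map(|_| bernoulli(&mut states, 0.5)).collect();
    println!("{draws:?}");
}
```

Promoting the probability to `f32` keeps the comparison in the same precision as the generated uniform value regardless of the tensor's element type.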
View File

@ -1,6 +1,7 @@
use cubecl::{
cpa,
ir::{Elem, FloatKind, Scope, Variable},
prelude::*,
};
use std::f32::consts::PI;
@ -47,7 +48,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
cpa!(
scope,
range(0u32, n_values_per_thread / 2).for_each(|i, scope| {
let int_random = scope.create_local(Elem::UInt);
let int_random = scope.create_local(u32::as_elem());
// First random uniform integer
taus_step_0(scope, state_0);
@ -95,9 +96,9 @@ impl<E: JitElement> Prng<E> for Normal<E> {
cpa!(scope, normal_1 += mean);
// Write to output
let write_index_0 = scope.create_local(Elem::UInt);
let write_index_1 = scope.create_local(Elem::UInt);
let iteration_offset = scope.create_local(Elem::UInt);
let write_index_0 = scope.create_local(u32::as_elem());
let write_index_1 = scope.create_local(u32::as_elem());
let iteration_offset = scope.create_local(u32::as_elem());
cpa!(scope, write_index_0 = write_index_base);
cpa!(scope, iteration_offset = two * i);
cpa!(scope, iteration_offset *= n_invocations);

View File

@ -2,6 +2,7 @@ use burn_tensor::Shape;
use cubecl::{
cpa,
ir::{Elem, FloatKind, Scope, Variable},
prelude::*,
};
use crate::{
@ -49,7 +50,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
taus_step_2(scope, state_2);
lcg_step(scope, state_3);
let int_random = scope.create_local(Elem::UInt);
let int_random = scope.create_local(u32::as_elem());
cpa!(scope, int_random = state_0 ^ state_1);
cpa!(scope, int_random = int_random ^ state_2);
cpa!(scope, int_random = int_random ^ state_3);
@ -65,7 +66,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
cpa!(scope, uniform = cast(uniform_float));
cpa!(scope, uniform += lower_bound);
let write_index = scope.create_local(Elem::UInt);
let write_index = scope.create_local(u32::as_elem());
cpa!(scope, write_index = i * n_invocations);
cpa!(scope, write_index += write_index_base);
cpa!(scope, output[write_index] = uniform);

View File

@ -3,6 +3,7 @@ use burn_tensor::cast::ToElement;
use cubecl::{
cpa,
ir::{Elem, Item, Scope, Variable},
prelude::*,
};
use super::base::ReduceDimShared;
@ -17,7 +18,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
input_item: Item,
) -> Self::Accumulator {
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
let index_shared_memory = scope.create_shared(u32::as_elem(), shared_memory_size);
let max = input_item
.elem()
.constant_from_f64(ToElement::to_f64(&E::minimum_value()));

View File

@ -2,6 +2,7 @@ use burn_tensor::cast::ToElement;
use cubecl::{
cpa,
ir::{Elem, Item, Scope, Variable},
prelude::*,
};
use crate::{kernel::reduce::Argmin, JitElement};
@ -18,7 +19,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
input_item: Item,
) -> Self::Accumulator {
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
let index_shared_memory = scope.create_shared(u32::as_elem(), shared_memory_size);
let min = input_item
.elem()
.constant_from_f64(ToElement::to_f64(&E::maximum_value()));

View File

@ -2,6 +2,7 @@ use cubecl::{
cpa,
ir::{Builtin, KernelDefinition, VariableKind},
prelude::CubeCount,
prelude::*,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@ -120,38 +121,38 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
let stride_reduce_dim_input = scope.create_local(Elem::UInt);
let stride_reduce_dim_input = scope.create_local(u32::as_elem());
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
let shape_reduce_dim_input = scope.create_local(Elem::UInt);
let shape_reduce_dim_input = scope.create_local(u32::as_elem());
cpa!(scope, shape_reduce_dim_input = shape(tensor, dim));
// To determine which reduce_group (not position, but absolute id)
let reduce_group_id = scope.create_local(Elem::UInt);
let reduce_group_id = scope.create_local(u32::as_elem());
cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
cpa!(scope, reduce_group_id += cube_pos_x);
// nth thread in the cube
let local_id = scope.create_local(Elem::UInt);
let local_id = scope.create_local(u32::as_elem());
cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
cpa!(scope, local_id += local_invocation_id_x);
let n_threads = scope.create_local(Elem::UInt);
let n_threads = scope.create_local(u32::as_elem());
cpa!(scope, n_threads = cube_dim_x * cube_dim_y);
let index_offset = scope.zero(Elem::UInt);
let index_offset = scope.zero(u32::as_elem());
cpa!(
scope,
range(0u32, rank).for_each(|i, scope| {
let stride_input = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_output = scope.create_local(Elem::UInt);
let stride_input = scope.create_local(u32::as_elem());
let stride_output = scope.create_local(u32::as_elem());
let shape_output = scope.create_local(u32::as_elem());
cpa!(scope, stride_input = stride(tensor, i));
cpa!(scope, stride_output = stride(output, i));
cpa!(scope, shape_output = shape(output, i));
let num_block = scope.create_local(Elem::UInt);
let num_block = scope.create_local(u32::as_elem());
cpa!(scope, num_block = reduce_group_id / stride_output);
cpa!(scope, num_block = num_block % shape_output);
cpa!(scope, num_block = num_block * stride_input);
@ -166,14 +167,14 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
cpa!(
scope,
range(0u32, self.n_input_values_per_thread).for_each(|i, scope| {
let nth = scope.create_local(Elem::UInt);
let nth = scope.create_local(u32::as_elem());
cpa!(scope, nth = i * n_threads);
cpa!(scope, nth += local_id);
let within_shape = scope.create_local(Elem::Bool);
if self.divisible_shape {
let current_position = scope.create_local(Elem::UInt);
let current_position = scope.create_local(u32::as_elem());
cpa!(scope, current_position = nth * stride_reduce_dim_input);
cpa!(scope, current_position += index_offset);
@ -182,7 +183,7 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
} else {
cpa!(scope, within_shape = nth < shape_reduce_dim_input);
cpa!(scope, if(within_shape).then(|scope|{
let current_position = scope.create_local(Elem::UInt);
let current_position = scope.create_local(u32::as_elem());
cpa!(scope, current_position = nth * stride_reduce_dim_input);
cpa!(scope, current_position += index_offset);
@ -208,7 +209,7 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
let updating_thread = scope.create_local(Elem::Bool);
cpa!(scope, updating_thread = local_id < n_threads);
cpa!(scope, if(updating_thread).then(|scope|{
let read_position = scope.create_local(Elem::UInt);
let read_position = scope.create_local(u32::as_elem());
cpa!(scope, read_position = n_threads + local_id);
let read_value = RD::read_from_shared(scope, shared_memory, read_position);

View File
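The shared-memory reduce shader maps each reduce group (one output element) to a base offset in the input by looping over the ranks: it recovers every output coordinate from the linear group id and the output strides, then re-projects that coordinate with the input strides. A plain-Rust sketch of that index arithmetic with illustrative names:

```rust
/// Compute the base input offset for one output element of a reduction,
/// mirroring the per-rank loop in the shared reduce shader. The output
/// shape has the reduced dimension set to 1, so it contributes nothing.
fn input_base_offset(
    reduce_group_id: usize,
    in_strides: &[usize],
    out_strides: &[usize],
    out_shape: &[usize],
) -> usize {
    let mut offset = 0;
    for i in 0..out_shape.len() {
        // Recover this output coordinate from the linear group id ...
        let coord = (reduce_group_id / out_strides[i]) % out_shape[i];
        // ... and re-project it onto the input layout.
        offset += coord * in_strides[i];
    }
    offset
}

fn main() {
    // Reduce dim 1 of a contiguous [2, 3, 4] tensor: output shape [2, 1, 4].
    let offset = input_base_offset(5, &[12, 4, 1], &[4, 4, 1], &[2, 1, 4]);
    assert_eq!(offset, 13); // output element (1, 0, 1) maps to input offset 1*12 + 1*1
}
```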

@ -44,6 +44,7 @@ impl<RD: ReduceDimAlgorithm<EI>, R: JitRuntime, EI: JitElement, EO: JitElement>
&input.shape,
&input.strides,
reduce_dim,
EI::dtype(),
)),
input,
output,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::{cmp::min, fmt::Display};
use burn_tensor::Shape;
use burn_tensor::{DType, Shape};
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
/// Autotune key representative of reduce versions
@ -9,14 +9,15 @@ pub struct ReduceAutotuneKey {
pub(crate) reduce_dim_length: usize,
pub(crate) reduce_dim_stride: usize,
pub(crate) others_product: usize,
dtype: DType,
}
impl Display for ReduceAutotuneKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
format!(
"Reduce - reduce_dim_length: {:?} reduce_dim_stride: {:?} others_product: {:?}",
self.reduce_dim_length, self.reduce_dim_stride, self.others_product
"Reduce - reduce_dim_length: {:?} reduce_dim_stride: {:?} others_product: {:?} dtype: {:?}",
self.reduce_dim_length, self.reduce_dim_stride, self.others_product, self.dtype
)
.as_str(),
)
@ -25,7 +26,7 @@ impl Display for ReduceAutotuneKey {
impl ReduceAutotuneKey {
/// Create a reduce autotune key from the input shape, strides, reduce dim and element dtype
pub fn new(shape: &Shape, strides: &[usize], reduce_dim: usize) -> Self {
pub fn new(shape: &Shape, strides: &[usize], reduce_dim: usize, dtype: DType) -> Self {
let ndims = strides.len();
let reduce_dim_length = shape.dims[reduce_dim];
let reduce_dim_stride = strides[reduce_dim];
@ -39,6 +40,7 @@ impl ReduceAutotuneKey {
reduce_dim_length: anchor(reduce_dim_length, None),
reduce_dim_stride: anchor(reduce_dim_stride, None),
others_product: anchor(others_product, None),
dtype,
}
}
}
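// Illustrative consequence (not part of this diff): since the key now carries the element
// dtype, autotune results are cached per element type, matching the call site above:
//
//     let key = ReduceAutotuneKey::new(&input.shape, &input.strides, reduce_dim, EI::dtype());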

View File

@ -21,6 +21,14 @@ pub(crate) async fn into_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R,
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}
#[allow(unused, reason = "useful for debugging kernels")]
pub(crate) fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R, E>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read(tensor.handle.binding());
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}
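// Illustrative usage (assumption, not part of this diff): the sync variant is handy when
// stepping through a kernel, since it avoids awaiting a future mid-debug, e.g.
//
//     let snapshot = into_data_sync(tensor.clone());
//     println!("{snapshot}");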
pub(crate) async fn bool_into_data<R: JitRuntime>(tensor: JitTensor<R, u32>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read_async(tensor.handle.binding()).await;

View File

@ -37,8 +37,12 @@ pub use serial_test;
#[macro_export]
macro_rules! testgen_all {
() => {
use burn_tensor::{Float, Int};
$crate::testgen_all!([Float], [Int]);
};
([$($float:ident),*], [$($int:ident),*]) => {
mod jit {
burn_jit::testgen_jit!();
burn_jit::testgen_jit!([$($float),*], [$($int),*]);
mod kernel {
use super::*;
@ -79,7 +83,7 @@ macro_rules! testgen_all {
}
}
mod jit_fusion {
burn_jit::testgen_jit_fusion!();
burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*]);
}
};
}
@ -87,22 +91,32 @@ macro_rules! testgen_all {
#[macro_export]
macro_rules! testgen_jit {
() => {
use super::*;
use burn_tensor::{Float, Int};
$crate::testgen_jit!([Float], [Int]);
};
([$($float:ident),*], [$($int:ident),*]) => {
pub use super::*;
use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test};
pub type TestBackend = JitBackend<TestRuntime, f32, i32>;
pub type TestBackend2<F, I> = JitBackend<TestRuntime, F, I>;
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type TestTensor2<F, I, const D: usize> = burn_tensor::Tensor<TestBackend2<F, I>, D>;
pub type TestTensorInt<const D: usize> =
burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorInt2<F, I, const D: usize> =
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> =
burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;
pub type TestTensorBool2<F, I, const D: usize> =
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Bool>;
pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
burn_tensor::testgen_all!();
burn_autodiff::testgen_all!();
burn_tensor::testgen_all!([$($float),*], [$($int),*]);
burn_autodiff::testgen_all!([$($float),*]);
// Not all ops are implemented for quantization yet, notably missing:
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
@ -111,28 +125,38 @@ macro_rules! testgen_jit {
burn_tensor::testgen_calibration!();
burn_tensor::testgen_scheme!();
burn_tensor::testgen_quantize!();
};
}
}
#[macro_export]
macro_rules! testgen_jit_fusion {
() => {
use burn_tensor::{Float, Int};
$crate::testgen_jit_fusion!([Float], [Int]);
};
([$($float:ident),*], [$($int:ident),*]) => {
use super::*;
use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor};
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime, f32, i32>>;
pub type TestBackend2<F, I> = burn_fusion::Fusion<JitBackend<TestRuntime, F, I>>;
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type TestTensor2<F, I, const D: usize> = burn_tensor::Tensor<TestBackend2<F, I>, D>;
pub type TestTensorInt<const D: usize> =
burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorInt2<F, I, const D: usize> =
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> =
burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;
pub type TestTensorBool2<F, I, const D: usize> =
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Bool>;
pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
burn_tensor::testgen_all!();
burn_autodiff::testgen_all!();
burn_tensor::testgen_all!([$($float),*], [$($int),*]);
burn_autodiff::testgen_all!([$($float),*]);
// Not all ops are implemented for quantization yet, notably missing:
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`

View File

@ -2,20 +2,25 @@
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Tensor library with user-friendly APIs and automatic differentiation support"
documentation = "https://docs.rs/burn-tensor"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-tensor"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor"
documentation = "https://docs.rs/burn-tensor"
version.workspace = true
[features]
cubecl = ["dep:cubecl"]
cubecl-cuda = ["cubecl", "cubecl/cuda"]
cubecl-hip = ["cubecl", "cubecl/hip"]
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
default = ["std", "repr"]
doc = ["default"]
experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"]
export_tests = ["burn-tensor-testgen", "cubecl"]
repr = []
std = [
"rand/std",
"half/std",
@ -24,24 +29,19 @@ std = [
"burn-common/rayon",
"colored",
]
repr = []
cubecl = ["dep:cubecl"]
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
cubecl-cuda = ["cubecl", "cubecl/cuda"]
cubecl-hip = ["cubecl", "cubecl/hip"]
[dependencies]
burn-common = { path = "../burn-common", version = "0.16.0", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true }
cubecl = { workspace = true, optional = true }
cubecl = { workspace = true, optional = true, default-features = true }
bytemuck = { workspace = true }
colored = { workspace = true, optional = true }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
num-traits = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
bytemuck = { workspace = true }
colored = { workspace = true, optional = true }
# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
hashbrown = { workspace = true } # no_std compatible
@ -50,6 +50,7 @@ hashbrown = { workspace = true } # no_std compatible
serde = { workspace = true }
serde_bytes = { workspace = true }
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }

View File

@ -20,7 +20,7 @@ pub mod repr;
#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
mod tests;
pub mod tests;
pub use half::{bf16, f16};
pub(crate) use tensor::check::macros::check;
@ -30,7 +30,7 @@ pub use burn_common::reader::*; // Useful so that backends don't have to add `bu
#[cfg(feature = "cubecl")]
mod cube {
use cubecl::ir::{Elem, FloatKind, IntKind};
use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind};
impl From<crate::DType> for cubecl::ir::Elem {
fn from(dtype: crate::DType) -> Self {
@ -41,11 +41,12 @@ mod cube {
crate::DType::BF16 => Elem::Float(FloatKind::BF16),
crate::DType::I64 => Elem::Int(IntKind::I64),
crate::DType::I32 => Elem::Int(IntKind::I32),
crate::DType::I16 => panic!("i16 isn't supported yet."),
crate::DType::I8 => panic!("i8 isn't supported yet."),
crate::DType::U64 => Elem::UInt,
crate::DType::U32 => Elem::UInt,
crate::DType::U8 => panic!("u8 isn't supported yet."),
crate::DType::I16 => Elem::Int(IntKind::I16),
crate::DType::I8 => Elem::Int(IntKind::I8),
crate::DType::U64 => Elem::UInt(UIntKind::U64),
crate::DType::U32 => Elem::UInt(UIntKind::U32),
crate::DType::U16 => Elem::UInt(UIntKind::U16),
crate::DType::U8 => Elem::UInt(UIntKind::U8),
crate::DType::Bool => Elem::Bool,
crate::DType::QFloat(_) => panic!("quantized type is not supported yet."),
}
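// Illustrative check (not part of this diff): the conversion now covers the narrow integer and
// unsigned kinds instead of panicking, e.g.
//
//     assert_eq!(Elem::from(crate::DType::U16), Elem::UInt(UIntKind::U16));
//     assert_eq!(Elem::from(crate::DType::I8), Elem::Int(IntKind::I8));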

View File

@ -1,4 +1,7 @@
use core::any::{Any, TypeId};
use core::{
any::{Any, TypeId},
f32,
};
use alloc::boxed::Box;
use alloc::format;
@ -181,6 +184,11 @@ impl TensorData {
.map(|e: &i64| e.elem::<E>()),
),
DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
DType::U16 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u16| e.elem::<E>()),
),
DType::U32 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
@ -313,6 +321,7 @@ impl TensorData {
DType::I8 => self.convert_inplace::<i8, E>(),
DType::U64 => self.convert_inplace::<u64, E>(),
DType::U32 => self.convert_inplace::<u32, E>(),
DType::U16 => self.convert_inplace::<u16, E>(),
DType::U8 => self.convert_inplace::<u8, E>(),
DType::Bool | DType::QFloat(_) => unreachable!(),
}
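// Illustrative (assumption, not part of this diff): with `u16` handled above, u16 tensor data
// can go through the generic conversion path like the other integer types, e.g.
//
//     let data = TensorData::from([1u16, 2, 3]);
//     let floats = data.convert::<f32>();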
@ -419,6 +428,7 @@ impl TensorData {
DType::I8 => self.assert_eq_elem::<i8>(other),
DType::U64 => self.assert_eq_elem::<u64>(other),
DType::U32 => self.assert_eq_elem::<u32>(other),
DType::U16 => self.assert_eq_elem::<u16>(other),
DType::U8 => self.assert_eq_elem::<u8>(other),
DType::Bool => self.assert_eq_elem::<bool>(other),
DType::QFloat(q) => {
@ -511,9 +521,21 @@ impl TensorData {
continue;
}
let err = ((a - b).pow(2.0f64)).sqrt();
let err = (a - b).abs();
if err > tolerance || err.is_nan() {
if self.dtype.is_float() {
if let Some((err, tolerance)) = compare_floats(a, b, self.dtype, tolerance) {
// Only print the first 5 different values.
if num_diff < max_num_diff {
message += format!(
"\n => Position {i}: {a} != {b} | difference {err} > tolerance \
{tolerance}"
)
.as_str();
}
num_diff += 1;
}
} else if err > tolerance || err.is_nan() {
// Only print the first 5 different values.
if num_diff < max_num_diff {
message += format!(
@ -683,6 +705,7 @@ impl core::fmt::Display for TensorData {
DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
DType::QFloat(q) => match &q {
@ -869,7 +892,7 @@ impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
}
#[allow(deprecated)]
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E, D> {
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data<E, D> {
/// Asserts the data is approximately equal to another data.
///
/// # Arguments
@ -926,9 +949,21 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
continue;
}
let err = ((a - b).pow(2.0f64)).sqrt();
let err = (a - b).abs();
if err > tolerance || err.is_nan() {
if E::dtype().is_float() {
if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) {
// Only print the first 5 different values.
if num_diff < max_num_diff {
message += format!(
"\n => Position {i}: {a} != {b} | difference {err} > tolerance \
{tolerance}"
)
.as_str();
}
num_diff += 1;
}
} else if err > tolerance || err.is_nan() {
// Only print the first 5 different values.
if num_diff < max_num_diff {
message += format!(
@ -1076,6 +1111,30 @@ impl<E: core::fmt::Debug, const D: usize> core::fmt::Display for Data<E, D> {
}
}
fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> {
let epsilon_deviations = tolerance / f32::EPSILON as f64;
let epsilon = match ty {
DType::F64 => f32::EPSILON as f64, // Don't increase precision beyond `f32`, see below
DType::F32 => f32::EPSILON as f64,
DType::F16 => half::f16::EPSILON.to_f64(),
DType::BF16 => half::bf16::EPSILON.to_f64(),
_ => unreachable!(),
};
let tolerance_norm = epsilon_deviations * epsilon;
// Floor the value's magnitude at 1.0 so we don't require more precision than `tolerance`.
// Expected values are literals with a fixed number of digits, so demanding tighter relative
// precision for small magnitudes would only cause spurious failures.
let value_abs = value.abs().max(1.0);
let tolerance_adjusted = tolerance_norm * value_abs;
let err = (value - other).abs();
if err > tolerance_adjusted || err.is_nan() {
Some((err, tolerance_adjusted))
} else {
None
}
}
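// Illustrative consequence (assumption, not part of this diff): because the tolerance above is
// rescaled by the dtype's epsilon, a test can keep a single precision argument across float
// types; the f16/bf16 instantiations simply get a proportionally looser bound, e.g.
//
//     output.into_data().assert_approx_eq(&expected, 3);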
#[cfg(test)]
#[allow(deprecated)]
mod tests {
@ -1140,16 +1199,16 @@ mod tests {
#[test]
fn should_assert_appox_eq_limit() {
let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
let data2 = TensorData::from([[3.01, 5.0, 6.0]]);
let data2 = TensorData::from([[3.03, 5.0, 6.0]]);
data1.assert_approx_eq(&data2, 2);
}
#[test]
#[should_panic]
fn should_assert_appox_eq_above_limit() {
fn should_assert_approx_eq_above_limit() {
let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
let data2 = TensorData::from([[3.011, 5.0, 6.0]]);
let data2 = TensorData::from([[3.031, 5.0, 6.0]]);
data1.assert_approx_eq(&data2, 2);
}

View File

@ -1,6 +1,8 @@
use core::cmp::Ordering;
use crate::{cast::ToElement, quantization::QuantizationStrategy, Distribution};
#[cfg(feature = "cubecl")]
use cubecl::flex32;
use half::{bf16, f16};
use rand::RngCore;
use serde::{Deserialize, Serialize};
@ -193,6 +195,14 @@ make_element!(
dtype DType::I16
);
make_element!(
ty u16 Precision::Half,
convert |elem: &dyn ToElement| elem.to_u16(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u16, b: &u16| Ord::cmp(a, b),
dtype DType::U16
);
make_element!(
ty i8 Precision::Other,
convert |elem: &dyn ToElement| elem.to_i8(),
@ -230,6 +240,18 @@ make_element!(
dtype DType::BF16
);
#[cfg(feature = "cubecl")]
make_element!(
ty flex32 Precision::Half,
convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
flex32::from_elem(sample)
},
cmp |a: &flex32, b: &flex32| a.total_cmp(b),
dtype DType::F32
);
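// Note (assumption, not part of this diff): `flex32` is cubecl's relaxed-precision float; on
// the host it converts through `f32` (see `from_f32`/`to_f32` above), which is why it reports
// `DType::F32` here. A minimal round-trip sketch:
//
//     let x = flex32::from_f32(1.5);
//     assert_eq!(flex32::to_f32(x), 1.5);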
make_element!(
ty bool Precision::Other,
convert |elem: &dyn ToElement| elem.to_u8() != 0,
@ -254,6 +276,7 @@ pub enum DType {
I8,
U64,
U32,
U16,
U8,
Bool,
QFloat(QuantizationStrategy),
@ -273,6 +296,7 @@ impl DType {
DType::I8 => core::mem::size_of::<i8>(),
DType::U64 => core::mem::size_of::<u64>(),
DType::U32 => core::mem::size_of::<u32>(),
DType::U16 => core::mem::size_of::<u16>(),
DType::U8 => core::mem::size_of::<u8>(),
DType::Bool => core::mem::size_of::<bool>(),
DType::QFloat(strategy) => match strategy {

View File

@ -442,6 +442,50 @@ impl ToElement for bf16 {
}
}
#[cfg(feature = "cubecl")]
impl ToElement for cubecl::flex32 {
#[inline]
fn to_i64(&self) -> i64 {
Self::to_f32(*self).to_i64()
}
#[inline]
fn to_u64(&self) -> u64 {
Self::to_f32(*self).to_u64()
}
#[inline]
fn to_i8(&self) -> i8 {
Self::to_f32(*self).to_i8()
}
#[inline]
fn to_u8(&self) -> u8 {
Self::to_f32(*self).to_u8()
}
#[inline]
fn to_i16(&self) -> i16 {
Self::to_f32(*self).to_i16()
}
#[inline]
fn to_u16(&self) -> u16 {
Self::to_f32(*self).to_u16()
}
#[inline]
fn to_i32(&self) -> i32 {
Self::to_f32(*self).to_i32()
}
#[inline]
fn to_u32(&self) -> u32 {
Self::to_f32(*self).to_u32()
}
#[inline]
fn to_f32(&self) -> f32 {
Self::to_f32(*self)
}
#[inline]
fn to_f64(&self) -> f64 {
Self::to_f64(*self)
}
}
impl ToElement for bool {
#[inline]
fn to_i64(&self) -> i64 {

View File

@ -15,7 +15,7 @@ mod tests {
#[test]
fn test_hard_sigmoid_overflow() {
let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);
let tensor = TestTensor::<1>::from([FloatType::MAX, FloatType::MIN]);
let output = activation::hard_sigmoid(tensor, 0.2, 0.5);
let expected = TensorData::from([1.0, 0.0]);

View File

@ -9,9 +9,9 @@ mod tests {
let output = activation::leaky_relu(tensor, 0.01);
output.into_data().assert_eq(
&TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]),
false,
);
// Account for conversion errors if `FloatType != f32`
output
.into_data()
.assert_approx_eq(&TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]), 5);
}
}

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(log_sigmoid)]
mod tests {
use super::*;
use burn_tensor::{activation, Tensor, TensorData};
use burn_tensor::{activation, cast::ToElement, tests::Numeric, Tensor, TensorData};
#[test]
fn test_log_sigmoid() {
@ -10,7 +10,7 @@ mod tests {
let output = activation::log_sigmoid(tensor);
let expected = TensorData::from([[-3.132617e-1, -9.114665e-4], [-2.260327e-6, -3.0485873]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
#[test]
@ -23,9 +23,10 @@ mod tests {
let expected = TensorData::from([0.0, -300.0]);
output.into_data().assert_approx_eq(&expected, 4);
let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);
let tensor =
TestTensor::<1>::from([<FloatType as Numeric>::MAX, <FloatType as Numeric>::MIN]);
let output = activation::log_sigmoid(tensor);
let expected = TensorData::from([0.0, f32::MIN]);
let expected = TensorData::from([0.0, <FloatType as Numeric>::MIN.to_f32()]);
output.into_data().assert_approx_eq(&expected, 4);
}

View File

@ -11,6 +11,6 @@ mod tests {
let output = activation::mish(tensor);
let expected = TensorData::from([[-0.1971, -0.3006, -0.1172], [-0.2413, 0.5823, -0.0888]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
}

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(sigmoid)]
mod tests {
use super::*;
use burn_tensor::{activation, Tensor, TensorData};
use burn_tensor::{activation, tests::Numeric, Tensor, TensorData};
#[test]
fn test_sigmoid() {
@ -10,12 +10,12 @@ mod tests {
let output = activation::sigmoid(tensor);
let expected = TensorData::from([[0.7311, 0.9991], [1.0, 0.0474]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
#[test]
fn test_sigmoid_overflow() {
let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);
let tensor = TestTensor::<1>::from([FloatType::MAX, FloatType::MIN]);
let output = activation::sigmoid(tensor);
let expected = TensorData::from([1.0, 0.0]);

View File

@ -10,6 +10,6 @@ mod tests {
let output = activation::silu(tensor);
let expected = TensorData::from([[0.7311, 1.7616], [2.8577, 3.9281]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
}

View File

@ -10,6 +10,6 @@ mod tests {
let output = activation::softmax(tensor, 1);
let expected = TensorData::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
}

View File

@ -10,6 +10,6 @@ mod tests {
let output = activation::softmin(tensor, 1);
let expected = TensorData::from([[9.9753e-01, 2.4726e-03], [1.1254e-07, 1.0000e+00]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
}

View File

@ -5,19 +5,17 @@ mod tests {
#[test]
fn test_softplus_d2() {
let tensor = Tensor::<TestBackend, 2>::from([
[-0.4240, -0.9574, -0.2215],
[-0.5767, 0.7218, -0.1620],
]);
let tensor =
TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);
let output = activation::softplus(tensor.clone(), 1.0);
let expected = TensorData::from([[0.5034, 0.3249, 0.5885], [0.4458, 1.1178, 0.6154]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
let output = activation::softplus(tensor, 2.0);
let expected = TensorData::from([[0.1782, 0.0687, 0.2480], [0.1371, 0.8277, 0.2721]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
}

View File

@ -10,6 +10,6 @@ mod tests {
let output = activation::tanh(tensor);
let expected = TensorData::from([[0.7616, 0.9640], [0.9951, 0.9993]]);
output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}
}

View File

@ -5,122 +5,59 @@ mod ops;
mod quantization;
mod stats;
pub use cubecl::prelude::{Float, Int, Numeric};
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_all {
() => {
// test activation
burn_tensor::testgen_gelu!();
burn_tensor::testgen_mish!();
burn_tensor::testgen_relu!();
burn_tensor::testgen_leaky_relu!();
burn_tensor::testgen_softmax!();
burn_tensor::testgen_softmin!();
burn_tensor::testgen_softplus!();
burn_tensor::testgen_sigmoid!();
burn_tensor::testgen_log_sigmoid!();
burn_tensor::testgen_silu!();
burn_tensor::testgen_tanh_activation!();
pub mod tensor {
pub use super::*;
// test module
burn_tensor::testgen_module_forward!();
burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_conv2d!();
burn_tensor::testgen_module_conv3d!();
burn_tensor::testgen_module_deform_conv2d!();
burn_tensor::testgen_module_conv_transpose1d!();
burn_tensor::testgen_module_conv_transpose2d!();
burn_tensor::testgen_module_conv_transpose3d!();
burn_tensor::testgen_module_unfold4d!();
burn_tensor::testgen_module_max_pool1d!();
burn_tensor::testgen_module_max_pool2d!();
burn_tensor::testgen_module_avg_pool1d!();
burn_tensor::testgen_module_avg_pool2d!();
burn_tensor::testgen_module_adaptive_avg_pool1d!();
burn_tensor::testgen_module_adaptive_avg_pool2d!();
burn_tensor::testgen_module_nearest_interpolate!();
burn_tensor::testgen_module_bilinear_interpolate!();
burn_tensor::testgen_module_bicubic_interpolate!();
pub type FloatType = <TestBackend as $crate::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as $crate::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as $crate::backend::Backend>::BoolTensorPrimitive;
// test ops
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arange!();
burn_tensor::testgen_arange_step!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_chunk!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_close!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();
burn_tensor::testgen_div!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_full!();
burn_tensor::testgen_gather_scatter!();
burn_tensor::testgen_init!();
burn_tensor::testgen_iter_dim!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_matmul!();
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_narrow!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_one_hot!();
burn_tensor::testgen_powf_scalar!();
burn_tensor::testgen_random!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_slice!();
burn_tensor::testgen_stack!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_abs!();
burn_tensor::testgen_squeeze!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_tri!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_any!();
burn_tensor::testgen_all_op!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_tri_mask!();
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_topk!();
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
burn_tensor::testgen_nan!();
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
$crate::testgen_with_float_param!();
$crate::testgen_no_param!();
}
};
([$($float:ident),*], [$($int:ident),*]) => {
pub mod tensor {
pub use super::*;
// test stats
burn_tensor::testgen_var!();
burn_tensor::testgen_cov!();
burn_tensor::testgen_eye!();
burn_tensor::testgen_display!();
pub type FloatType = <TestBackend as $crate::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as $crate::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as $crate::backend::Backend>::BoolTensorPrimitive;
// test clone invariance
burn_tensor::testgen_clone_invariance!();
::paste::paste! {
$(mod [<$float _ty>] {
pub use super::*;
// test padding
burn_tensor::testgen_padding!();
pub type TestBackend = TestBackend2<$float, IntType>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;
pub type FloatType = $float;
$crate::testgen_with_float_param!();
})*
$(mod [<$int _ty>] {
pub use super::*;
pub type TestBackend = TestBackend2<FloatType, $int>;
pub type TestTensor<const D: usize> = TestTensor2<FloatType, $int, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<FloatType, $int, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<FloatType, $int, D>;
pub type IntType = $int;
$crate::testgen_with_int_param!();
})*
}
$crate::testgen_no_param!();
}
};
}
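// Illustrative invocation (assumption, not part of this diff): with the parameterized form a
// backend test suite can be instantiated once per element type, e.g.
//
//     burn_tensor::testgen_all!([f32, flex32, f16], [i32, u32]);
//
// `paste` then expands per-type modules (`f32_ty`, `flex32_ty`, ...) that redefine the
// `TestBackend`/`TestTensor` aliases with `FloatType` or `IntType` bound to that element.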
@ -178,3 +115,191 @@ macro_rules! testgen_quantization {
burn_tensor::testgen_q_transpose!();
};
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_with_float_param {
() => {
// test activation
burn_tensor::testgen_gelu!();
burn_tensor::testgen_mish!();
burn_tensor::testgen_relu!();
burn_tensor::testgen_leaky_relu!();
burn_tensor::testgen_softmax!();
burn_tensor::testgen_softmin!();
burn_tensor::testgen_softplus!();
burn_tensor::testgen_sigmoid!();
burn_tensor::testgen_log_sigmoid!();
burn_tensor::testgen_silu!();
burn_tensor::testgen_tanh_activation!();
// test module
burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_conv2d!();
burn_tensor::testgen_module_conv3d!();
burn_tensor::testgen_module_forward!();
burn_tensor::testgen_module_deform_conv2d!();
burn_tensor::testgen_module_conv_transpose1d!();
burn_tensor::testgen_module_conv_transpose2d!();
burn_tensor::testgen_module_conv_transpose3d!();
burn_tensor::testgen_module_unfold4d!();
burn_tensor::testgen_module_max_pool1d!();
burn_tensor::testgen_module_max_pool2d!();
burn_tensor::testgen_module_avg_pool1d!();
burn_tensor::testgen_module_avg_pool2d!();
burn_tensor::testgen_module_adaptive_avg_pool1d!();
burn_tensor::testgen_module_adaptive_avg_pool2d!();
burn_tensor::testgen_module_nearest_interpolate!();
burn_tensor::testgen_module_bilinear_interpolate!();
burn_tensor::testgen_module_bicubic_interpolate!();
// test ops
burn_tensor::testgen_gather_scatter!();
burn_tensor::testgen_narrow!();
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arange!();
burn_tensor::testgen_arange_step!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_chunk!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_close!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();
burn_tensor::testgen_div!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_full!();
burn_tensor::testgen_init!();
burn_tensor::testgen_iter_dim!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_matmul!();
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_one_hot!();
burn_tensor::testgen_powf_scalar!();
burn_tensor::testgen_random!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_slice!();
burn_tensor::testgen_stack!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_abs!();
burn_tensor::testgen_squeeze!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_tri!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_any!();
burn_tensor::testgen_all_op!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_tri_mask!();
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_topk!();
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
burn_tensor::testgen_nan!();
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
burn_tensor::testgen_select!();
// test stats
burn_tensor::testgen_var!();
burn_tensor::testgen_cov!();
burn_tensor::testgen_eye!();
// test padding
burn_tensor::testgen_padding!();
};
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_with_int_param {
() => {
// test ops
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_div!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_stack!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_gather_scatter!();
// test stats
burn_tensor::testgen_eye!();
// test padding
burn_tensor::testgen_padding!();
};
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_no_param {
() => {
// test stats
burn_tensor::testgen_display!();
// test clone invariance
burn_tensor::testgen_clone_invariance!();
};
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! as_bytes {
($ty:ident: $($elem:expr),*) => {
F::as_bytes(&[$($ty::new($elem),)*])
};
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! as_type {
($ty:ident: [$($elem:tt),*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: [$($elem:tt,)*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: $elem:expr) => {
{
use $crate::tests::{Float, Int};
$ty::new($elem)
}
};
}
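// Illustrative usage (assumption, not part of this diff): `as_type!` builds literals in the
// active element type via the cubecl `new` constructor, e.g.
//
//     let expected = TensorData::from(as_type!(FloatType: [[1.0, 2.0], [3.0, 4.0]]));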

View File

@ -106,7 +106,7 @@ mod tests {
self.count_include_pad,
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq(&output.into_data(), 2);
}
}
}

View File

@ -111,7 +111,7 @@ mod tests {
!output
.clone()
.to_data()
.as_slice::<f32>()
.as_slice::<FloatType>()
.unwrap()
.iter()
.any(|&x| x.is_nan()),
@ -148,7 +148,7 @@ mod tests {
InterpolateOptions::new(InterpolateMode::Bicubic),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.3);
}
}
}

View File

@ -111,7 +111,7 @@ mod tests {
!output
.clone()
.to_data()
.as_slice::<f32>()
.as_slice::<FloatType>()
.unwrap()
.iter()
.any(|&x| x.is_nan()),
@ -156,7 +156,7 @@ mod tests {
InterpolateOptions::new(InterpolateMode::Bilinear),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.3);
}
}
}

View File

@ -26,7 +26,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[0.9074, 0.6387], [0.5160, 0.4196]],
[[2.4259, 1.8008], [1.5449, 1.3112]],
[[3.9444, 2.9629], [2.5738, 2.2027]],
@ -55,7 +55,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([
test.assert_output(TestTensor::<4>::from([
[
[[0.2155, 0.1928], [0.1934, 0.1755]],
[[0.7251, 0.6759], [0.6877, 0.6485]],
@ -93,7 +93,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[0.1018, 0.0658], [0.0467, 0.0362]],
[[0.4125, 0.3367], [0.3069, 0.2824]],
[[1.3076, 1.0242], [0.9025, 0.8000]],
@ -123,7 +123,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[1.0794, 0.7676], [0.7209, 0.5337]],
[[2.7059, 2.0216], [1.9740, 1.5419]],
[[4.3325, 3.2755], [3.2271, 2.5501]],
@ -153,7 +153,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[1.0669], [0.6329]],
[[2.9741], [2.0383]],
[[4.8812], [3.4437]],
@ -180,7 +180,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[
[
0.1998, 0.3762, 0.5285, 0.6053, 0.3844, 0.1987, 0.0481, 0.0000,
@ -264,7 +264,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[1.0647], [0.5783]],
[[2.9289], [1.8829]],
[[4.7931], [3.1875]],
@ -291,7 +291,7 @@ mod tests {
width: 5,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[0.6162], [0.7611], [0.4666]],
[[1.8578], [2.2684], [1.6208]],
[[3.0994], [3.7757], [2.7749]],
@ -318,7 +318,7 @@ mod tests {
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[
[0.8909, 0.6016],
[1.0697, 0.7186],
@ -389,29 +389,29 @@ mod tests {
out_width,
]);
let device = Default::default();
let weight = Tensor::<TestBackend, 4>::from(
let weight = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight.clone())
.into_data(),
)
.div_scalar(shape_weight.num_elements() as f32);
let bias = Tensor::<TestBackend, 1>::from(
let bias = TestTensor::<1>::from(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
)
.div_scalar(self.channels_out as f32);
let x = Tensor::<TestBackend, 4>::from(
let x = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x.clone())
.into_data(),
)
.div_scalar(shape_x.num_elements() as f32);
let offset = Tensor::<TestBackend, 4>::from(
let offset = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
.reshape::<4, _>(shape_offset.clone())
.into_data(),
)
.div_scalar(shape_offset.num_elements() as f32);
let mask = Tensor::<TestBackend, 4>::from(
let mask = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
.reshape::<4, _>(shape_mask.clone())
.into_data(),
@ -433,7 +433,7 @@ mod tests {
),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.04);
}
}
}

View File

@ -7,8 +7,8 @@ mod tests {
fn test_embedding_forward() {
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indices = TensorData::from([[0, 1], [1, 1]]);
let weights = Tensor::<TestBackend, 2>::from(weights);
let indices = Tensor::<TestBackend, 2, Int>::from(indices);
let weights = TestTensor::<2>::from(weights);
let indices = TestTensorInt::<2>::from(indices);
let output = embedding(weights, indices);
let expected = TensorData::from([

View File

@ -84,7 +84,7 @@ mod tests {
!output
.clone()
.to_data()
.as_slice::<f32>()
.as_slice::<FloatType>()
.unwrap()
.iter()
.any(|&x| x.is_nan()),
@ -122,7 +122,7 @@ mod tests {
InterpolateOptions::new(InterpolateMode::Nearest),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.2);
}
}
}

View File

@ -8,24 +8,24 @@ mod tests {
fn test_arange() {
let device = <TestBackend as Backend>::Device::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..5, &device);
let tensor = TestTensorInt::<1>::arange(2..5, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([2, 3, 4]), false);
// Test arange with negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange(-10..-5, &device);
let tensor = TestTensorInt::<1>::arange(-10..-5, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([-10, -9, -8, -7, -6]), false);
let tensor = Tensor::<TestBackend, 1, Int>::arange(-3..0, &device);
let tensor = TestTensorInt::<1>::arange(-3..0, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([-3, -2, -1]), false);
// Test arange with a mix of positive and negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange(-2..3, &device);
let tensor = TestTensorInt::<1>::arange(-2..3, &device);
tensor
.clone()
.into_data()

View File

@ -9,28 +9,28 @@ mod tests {
let device = <TestBackend as Backend>::Device::default();
// Test correct sequence of numbers when the range is 0..9 and the step is 1
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..9, 1, &device);
let tensor = TestTensorInt::<1>::arange_step(0..9, 1, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([0, 1, 2, 3, 4, 5, 6, 7, 8]), false);
// Test correct sequence of numbers when the range is 0..3 and the step is 2
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 2, &device);
let tensor = TestTensorInt::<1>::arange_step(0..3, 2, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([0, 2]), false);
// Test correct sequence of numbers when the range is 0..2 and the step is 5
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..2, 5, &device);
let tensor = TestTensorInt::<1>::arange_step(0..2, 5, &device);
tensor.into_data().assert_eq(&TensorData::from([0]), false);
// Test correct sequence of numbers when the range includes negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-3..3, 2, &device);
let tensor = TestTensorInt::<1>::arange_step(-3..3, 2, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([-3, -1, 1]), false);
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-5..1, 5, &device);
let tensor = TestTensorInt::<1>::arange_step(-5..1, 5, &device);
tensor
.clone()
.into_data()
@ -43,6 +43,6 @@ mod tests {
fn should_panic_when_step_is_zero() {
let device = <TestBackend as Backend>::Device::default();
// Test that arange_step panics when the step is 0
let _tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 0, &device);
let _tensor = TestTensorInt::<1>::arange_step(0..3, 0, &device);
}
}

View File

@ -9,15 +9,14 @@ mod tests {
let device = <TestBackend as Backend>::Device::default();
// Test a single element tensor
let tensor: Tensor<TestBackend, 2, Int> =
Tensor::<TestBackend, 1, Int>::cartesian_grid([1], &device);
let tensor: Tensor<TestBackend, 2, Int> = TestTensorInt::<1>::cartesian_grid([1], &device);
tensor
.into_data()
.assert_eq(&TensorData::from([[0]]), false);
// Test for a 2x2 tensor
let tensor: Tensor<TestBackend, 3, Int> =
Tensor::<TestBackend, 2, Int>::cartesian_grid([2, 2], &device);
TestTensorInt::<2>::cartesian_grid([2, 2], &device);
tensor.into_data().assert_eq(
&TensorData::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]),
false,

View File

@ -31,9 +31,7 @@ mod tests {
#[test]
fn cast_bool_to_float_tensor() {
let tensor =
Tensor::<TestBackend, 2, Bool>::from([[true, false, true], [false, false, true]])
.float();
let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).float();
let expected = TensorData::from([[1., 0., 1.], [0., 0., 1.]]);

View File

@ -18,8 +18,8 @@ mod tests {
#[test]
fn should_support_cat_ops_int() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 3]], &device);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data([[4, 5, 6]], &device);
let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device);
let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device);
let output = Tensor::cat(vec![tensor_1, tensor_2], 0);
@ -31,8 +31,8 @@ mod tests {
#[test]
fn should_support_cat_ops_bool() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data([[false, true, true]], &device);
let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data([[true, true, false]], &device);
let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);
let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);
let output = Tensor::cat(vec![tensor_1, tensor_2], 0);

View File

@ -6,7 +6,7 @@ mod tests {
#[test]
fn should_support_ceil_ops() {
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let tensor = TestTensor::<2>::from_data(data, &Default::default());
let output = tensor.ceil();
let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]);

View File

@ -8,7 +8,7 @@ mod tests {
let device = Default::default();
// test float tensor
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);
let output = tensor.clamp_min(2.0);
@ -18,7 +18,7 @@ mod tests {
// test int tensor
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device);
let tensor = TestTensorInt::<2>::from_data(data, &device);
let output = tensor.clamp_min(2);
output
@ -31,7 +31,7 @@ mod tests {
let device = Default::default();
// test float tensor
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);
let output = tensor.clamp_max(2.0);
@ -41,7 +41,7 @@ mod tests {
// test int tensor
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device);
let tensor = TestTensorInt::<2>::from_data(data, &device);
let output = tensor.clamp_max(4);
output
@ -54,7 +54,7 @@ mod tests {
let device = Default::default();
// test float tensor
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);
let output = tensor.clamp(1.0, 4.0);
output
@ -63,7 +63,7 @@ mod tests {
// test int tensor
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device);
let tensor = TestTensorInt::<2>::from_data(data, &device);
let output = tensor.clamp(1, 4);
output

View File

@ -6,7 +6,7 @@ mod tests {
#[test]
fn should_support_cos_ops() {
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let tensor = TestTensor::<2>::from_data(data, &Default::default());
let output = tensor.cos();
let expected = TensorData::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]);

View File

@ -8,8 +8,8 @@ mod tests {
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device);
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
let output = tensor_1 / tensor_2;
let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
@ -22,8 +22,8 @@ mod tests {
let data_1 = TensorData::from([[0.0, 1.0, 2.0]]);
let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device);
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
let output = tensor_1 / tensor_2;
@ -38,7 +38,7 @@ mod tests {
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let scalar = 2.0;
let device = Default::default();
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);
let output = tensor / scalar;
@ -52,8 +52,8 @@ mod tests {
let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let data_2 = TensorData::from([[1, 1, 2], [1, 1, 2]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data(data_2, &device);
let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);
let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);
let output = tensor_1 / tensor_2;
@ -67,8 +67,8 @@ mod tests {
let data_1 = TensorData::from([[0, 1, 2]]);
let data_2 = TensorData::from([[1, 1, 2], [3, 4, 5]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data(data_2, &device);
let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);
let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);
let output = tensor_1 / tensor_2;
@ -81,7 +81,7 @@ mod tests {
fn should_support_div_scalar_ops_int() {
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let scalar = 2;
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
let tensor = TestTensorInt::<2>::from_data(data, &Default::default());
let output = tensor / scalar;

Some files were not shown because too many files have changed in this diff.