mirror of https://github.com/tracel-ai/burn.git

Add more type support for burn-jit (#2454)
parent 5597657314
commit 42f39f16b3
@ -313,7 +313,7 @@ dependencies = [
|
|||
"clap 4.5.20",
|
||||
"colored",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"dirs 5.0.1",
|
||||
"github-device-flow",
|
||||
"half",
|
||||
|
@ -321,7 +321,7 @@ dependencies = [
|
|||
"os_info",
|
||||
"percent-encoding",
|
||||
"rand",
|
||||
"reqwest 0.12.8",
|
||||
"reqwest 0.12.9",
|
||||
"rstest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -503,7 +503,7 @@ dependencies = [
|
|||
"burn-common",
|
||||
"burn-tensor",
|
||||
"burn-tensor-testgen",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"log",
|
||||
"spin",
|
||||
]
|
||||
|
@ -516,7 +516,7 @@ dependencies = [
|
|||
"burn-tch",
|
||||
"burn-tensor",
|
||||
"candle-core",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"half",
|
||||
]
|
||||
|
||||
|
@ -529,7 +529,7 @@ dependencies = [
|
|||
"getrandom",
|
||||
"indicatif",
|
||||
"rayon",
|
||||
"reqwest 0.12.8",
|
||||
"reqwest 0.12.9",
|
||||
"tokio",
|
||||
"web-time",
|
||||
]
|
||||
|
@ -552,7 +552,7 @@ dependencies = [
|
|||
"burn-tensor",
|
||||
"burn-wgpu",
|
||||
"data-encoding",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"flate2",
|
||||
"half",
|
||||
"hashbrown 0.15.0",
|
||||
|
@ -579,9 +579,10 @@ dependencies = [
|
|||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"half",
|
||||
"log",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -590,7 +591,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn-common",
|
||||
"csv",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"dirs 5.0.1",
|
||||
"fake",
|
||||
"flate2",
|
||||
|
@ -620,7 +621,7 @@ dependencies = [
|
|||
name = "burn-derive"
|
||||
version = "0.16.0"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
|
@ -632,7 +633,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn-common",
|
||||
"burn-tensor",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"half",
|
||||
"hashbrown 0.15.0",
|
||||
"log",
|
||||
|
@ -649,7 +650,7 @@ dependencies = [
|
|||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
@ -660,7 +661,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn",
|
||||
"candle-core",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"half",
|
||||
"log",
|
||||
"onnx-ir",
|
||||
|
@ -691,12 +692,13 @@ dependencies = [
|
|||
"burn-tensor-testgen",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"futures-lite",
|
||||
"half",
|
||||
"hashbrown 0.15.0",
|
||||
"log",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"rand",
|
||||
"serde",
|
||||
"serial_test",
|
||||
|
@ -713,7 +715,7 @@ dependencies = [
|
|||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-tensor",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"libm",
|
||||
"matrixmultiply",
|
||||
"ndarray 0.16.1",
|
||||
|
@ -767,7 +769,7 @@ dependencies = [
|
|||
"bytemuck",
|
||||
"colored",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"half",
|
||||
"hashbrown 0.15.0",
|
||||
"num-traits",
|
||||
|
@ -792,7 +794,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn-core",
|
||||
"burn-ndarray",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"log",
|
||||
"nvml-wrapper",
|
||||
"ratatui",
|
||||
|
@ -812,6 +814,8 @@ dependencies = [
|
|||
"burn-jit",
|
||||
"burn-tensor",
|
||||
"cubecl",
|
||||
"half",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -881,15 +885,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.6.0"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
|
||||
checksum = "7e1a39b963e261c58017edf2007e5b63425ad21538aaaf51fe23d1da41703701"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"byteorder",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
"cudarc 0.11.9",
|
||||
"cudarc",
|
||||
"gemm",
|
||||
"half",
|
||||
"libc",
|
||||
|
@ -908,18 +912,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.6.0"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
|
||||
checksum = "539cbfbf2d1d68a6ed97115e579c77c98f8ed0cfe7edbc6d7d30d2ac0c9e3d50"
|
||||
dependencies = [
|
||||
"bindgen_cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.6.0"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f889aacd02fd999620a0435133d7cf3b58c81ef9dd5e47c38939b7a72345ea86"
|
||||
checksum = "166a92826d615d98b205e346e52128fa0439f2ab3302587403fdc558b4219e19"
|
||||
dependencies = [
|
||||
"metal 0.27.0",
|
||||
"once_cell",
|
||||
|
@ -1460,7 +1464,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"cubecl-core",
|
||||
"cubecl-cuda",
|
||||
|
@ -1476,7 +1480,7 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"embassy-futures",
|
||||
"futures-lite",
|
||||
"getrandom",
|
||||
|
@ -1491,9 +1495,9 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-common"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"embassy-futures",
|
||||
"futures-lite",
|
||||
"getrandom",
|
||||
|
@ -1508,13 +1512,14 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-core"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common 0.4.0",
|
||||
"cubecl-macros",
|
||||
"cubecl-runtime 0.4.0",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"derive_more 1.0.0",
|
||||
"half",
|
||||
"log",
|
||||
"num-traits",
|
||||
|
@ -1525,13 +1530,13 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-cpp"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common 0.4.0",
|
||||
"cubecl-core",
|
||||
"cubecl-runtime 0.4.0",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
@ -1539,15 +1544,15 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-cuda"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common 0.4.0",
|
||||
"cubecl-core",
|
||||
"cubecl-cpp",
|
||||
"cubecl-runtime 0.4.0",
|
||||
"cudarc 0.12.1",
|
||||
"derive-new",
|
||||
"cudarc",
|
||||
"derive-new 0.6.0",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
@ -1555,7 +1560,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-hip"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common 0.4.0",
|
||||
|
@ -1563,7 +1568,7 @@ dependencies = [
|
|||
"cubecl-cpp",
|
||||
"cubecl-hip-sys",
|
||||
"cubecl-runtime 0.4.0",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
@ -1580,7 +1585,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-linalg"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-core",
|
||||
|
@ -1591,11 +1596,11 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-macros"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"cubecl-common 0.4.0",
|
||||
"darling",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"ident_case",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
|
@ -1606,7 +1611,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-opt"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"cubecl-common 0.4.0",
|
||||
"cubecl-core",
|
||||
|
@ -1628,7 +1633,7 @@ dependencies = [
|
|||
"async-lock",
|
||||
"cfg_aliases 0.2.1",
|
||||
"cubecl-common 0.3.0",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"dirs 5.0.1",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
|
@ -1643,13 +1648,13 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-runtime"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"async-lock",
|
||||
"cfg_aliases 0.2.1",
|
||||
"cubecl-common 0.4.0",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"dirs 5.0.1",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
|
@ -1664,12 +1669,13 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-spirv"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"cubecl-common 0.4.0",
|
||||
"cubecl-core",
|
||||
"cubecl-opt",
|
||||
"cubecl-runtime 0.4.0",
|
||||
"half",
|
||||
"hashbrown 0.14.5",
|
||||
"rspirv",
|
||||
]
|
||||
|
@ -1677,7 +1683,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-wgpu"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9460a3244aa2b42e1d6c36bd25b65f814f81ecd0#9460a3244aa2b42e1d6c36bd25b65f814f81ecd0"
|
||||
dependencies = [
|
||||
"ash",
|
||||
"async-channel",
|
||||
|
@ -1688,29 +1694,20 @@ dependencies = [
|
|||
"cubecl-core",
|
||||
"cubecl-runtime 0.4.0",
|
||||
"cubecl-spirv",
|
||||
"derive-new",
|
||||
"derive-new 0.6.0",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"web-time",
|
||||
"wgpu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a5bd4d1eee570c3b2ac64ed114125517dd1e541d88dd28fc259f1de4dba8d60"
|
||||
dependencies = [
|
||||
"half",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
|
||||
dependencies = [
|
||||
"half",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
|
@ -1720,7 +1717,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn",
|
||||
"csv",
|
||||
"reqwest 0.12.8",
|
||||
"reqwest 0.12.9",
|
||||
"serde",
|
||||
]
|
||||
|
||||
|
@ -1732,7 +1729,7 @@ dependencies = [
|
|||
"burn-jit",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"log",
|
||||
"serde",
|
||||
]
|
||||
|
@ -1752,7 +1749,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn",
|
||||
"bytemuck",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"guide",
|
||||
"log",
|
||||
"serde",
|
||||
|
@ -1764,7 +1761,7 @@ version = "0.16.0"
|
|||
dependencies = [
|
||||
"burn",
|
||||
"bytemuck",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"guide",
|
||||
"log",
|
||||
"serde",
|
||||
|
@ -1777,7 +1774,7 @@ dependencies = [
|
|||
"burn",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"log",
|
||||
"serde",
|
||||
]
|
||||
|
@ -1874,6 +1871,17 @@ dependencies = [
|
|||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive-new"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_arbitrary"
|
||||
version = "1.3.2"
|
||||
|
@ -1927,6 +1935,27 @@ dependencies = [
|
|||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
|
||||
dependencies = [
|
||||
"derive_more-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more-impl"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deunicode"
|
||||
version = "1.6.0"
|
||||
|
@ -2224,9 +2253,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
|
|||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.5"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
|
||||
checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb"
|
||||
dependencies = [
|
||||
"simd-adler32",
|
||||
]
|
||||
|
@ -3191,9 +3220,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.9"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b"
|
||||
checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
|
@ -5669,9 +5698,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.8"
|
||||
version = "0.12.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b"
|
||||
checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
|
@ -5843,9 +5872,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.37"
|
||||
version = "0.38.38"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811"
|
||||
checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"errno",
|
||||
|
@ -5856,9 +5885,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.15"
|
||||
version = "0.23.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993"
|
||||
checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e"
|
||||
dependencies = [
|
||||
"log",
|
||||
"once_cell",
|
||||
|
@ -5970,9 +5999,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "scc"
|
||||
version = "2.2.2"
|
||||
version = "2.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2c1f7fc6deb21665a9060dfc7d271be784669295a31babdcd4dd2c79ae8cbfb"
|
||||
checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8"
|
||||
dependencies = [
|
||||
"sdd",
|
||||
]
|
||||
|
@ -6649,7 +6678,7 @@ name = "text-classification"
|
|||
version = "0.16.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"serde",
|
||||
"tokenizers",
|
||||
]
|
||||
|
@ -6659,7 +6688,7 @@ name = "text-generation"
|
|||
version = "0.16.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"derive-new",
|
||||
"derive-new 0.7.0",
|
||||
"log",
|
||||
"serde",
|
||||
"tokenizers",
|
||||
|
@ -6946,7 +6975,7 @@ checksum = "4126466aafe1c518cb5c23979c286903cb1d1ff1bc3b76891254a243a0ed1e15"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"clap 4.5.20",
|
||||
"derive_more",
|
||||
"derive_more 0.99.18",
|
||||
"env_logger",
|
||||
"log",
|
||||
"rand",
|
||||
@@ -29,7 +29,7 @@ version = "0.16.0"
[workspace.dependencies]
atomic_float = "1"
bytemuck = "1.19.0"
candle-core = { version = "0.6.0" }
candle-core = { version = "0.7" }
clap = { version = "4.5.20", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
@@ -53,6 +53,7 @@ js-sys = "0.3.69"
libm = "0.2.9"
log = { default-features = false, version = "0.4.22" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.41.3", features = ["lazy"] }
pretty_assertions = "1.4.1"
@@ -117,7 +118,7 @@ bincode = { version = "2.0.0-rc.3", features = [
#
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.6.0", default-features = false }
derive-new = { version = "0.7.0", default-features = false }

blas-src = { version = "0.10.0", default-features = false }
half = { version = "2.4.1", features = [
@@ -152,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
@@ -73,9 +73,9 @@ impl AutodiffClient for ChannelClient {
            .unwrap()
    }

    fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
    fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
        let node_id = root.node.id;
        let grads = Gradients::new::<B, D>(root.node, root.primitive);
        let grads = Gradients::new::<B>(root.node, root.primitive);
        let (callback, receiver) = std::sync::mpsc::channel();

        self.sender
@ -20,10 +20,10 @@ mod tests {
|
|||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
|
||||
grad_1.to_data().assert_approx_eq(&expected, 3);
|
||||
grad_1.to_data().assert_approx_eq(&expected, 5);
|
||||
|
||||
let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
|
||||
grad_2.to_data().assert_approx_eq(&expected, 3);
|
||||
grad_2.to_data().assert_approx_eq(&expected, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -42,10 +42,10 @@ mod tests {
|
|||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
|
||||
grad_1.to_data().assert_approx_eq(&expected, 3);
|
||||
grad_1.to_data().assert_approx_eq(&expected, 5);
|
||||
|
||||
let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
|
||||
grad_2.to_data().assert_approx_eq(&expected, 3);
|
||||
grad_2.to_data().assert_approx_eq(&expected, 5);
|
||||
|
||||
let contains_nan = grad_2.contains_nan();
|
||||
assert_eq!(contains_nan.into_scalar(), false);
|
||||
|
|
|
@ -46,7 +46,7 @@ mod tests {
|
|||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ mod tests {
|
|||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -97,7 +97,7 @@ mod tests {
|
|||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ mod tests {
|
|||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,21 +40,21 @@ mod tests {
|
|||
.clone()
|
||||
.slice([0..1])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_1_slice_1.to_data(), 3);
|
||||
.assert_approx_eq(&grad_1_slice_1.to_data(), 5);
|
||||
grad_1
|
||||
.slice([1..2])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_1_slice_2.to_data(), 3);
|
||||
.assert_approx_eq(&grad_1_slice_2.to_data(), 5);
|
||||
|
||||
grad_2
|
||||
.clone()
|
||||
.slice([0..1])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_2_slice_1.to_data(), 3);
|
||||
.assert_approx_eq(&grad_2_slice_1.to_data(), 5);
|
||||
grad_2
|
||||
.slice([1..2])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_2_slice_2.to_data(), 3);
|
||||
.assert_approx_eq(&grad_2_slice_2.to_data(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -265,15 +265,15 @@ mod tests {
|
|||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -883,15 +883,15 @@ mod tests {
|
|||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -512,15 +512,15 @@ mod tests {
|
|||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -281,15 +281,15 @@ mod tests {
|
|||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -694,15 +694,15 @@ mod tests {
|
|||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -567,15 +567,15 @@ mod tests {
|
|||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 5);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1408,7 +1408,7 @@ mod tests {
|
|||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
|
||||
.assert_approx_eq_diff(&weight_grad_actual.to_data(), 0.04);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,11 +67,54 @@ mod transpose;

#[macro_export]
macro_rules! testgen_all {
    // Avoid using paste dependency with no parameters
    () => {
        mod autodiff {
            pub use super::*;
            type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
            type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;

            // Behavior
            pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
            pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
            pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;

            $crate::testgen_with_float_param!();
        }
    };
    ([$($float:ident),*]) => {
        mod autodiff {
            pub use super::*;
            type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
            type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;

            pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
            pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
            pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;

            ::paste::paste! {
                $(mod [<$float _ty>] {
                    pub use super::*;

                    pub type TestBackend = TestBackend2<$float, IntType>;
                    pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
                    pub type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
                    pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
                    pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
                    pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;

                    type FloatType = $float;

                    $crate::testgen_with_float_param!();
                })*
            }
        }
    };
}

#[macro_export]
macro_rules! testgen_with_float_param {
    () => {
        // Behaviour
        burn_autodiff::testgen_ad_broadcast!();
        burn_autodiff::testgen_gradients!();
        burn_autodiff::testgen_bridge!();
@ -20,9 +20,9 @@ mod tests {
|
|||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[82.1126, 99.0832], [82.1126, 99.0832]]);
|
||||
grad_1.to_data().assert_approx_eq(&expected, 3);
|
||||
grad_1.to_data().assert_approx_eq_diff(&expected, 0.02);
|
||||
|
||||
let expected = TensorData::from([[30.3093, 33.1204], [34.5819, 38.7694]]);
|
||||
grad_2.to_data().assert_approx_eq(&expected, 3);
|
||||
grad_2.to_data().assert_approx_eq(&expected, 2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,10 +21,10 @@ mod tests {
|
|||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), false);
|
||||
.assert_approx_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), 3);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), false);
|
||||
.assert_approx_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -48,13 +48,13 @@ mod tests {
|
|||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(
|
||||
grad_1.to_data().assert_approx_eq(
|
||||
&TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),
|
||||
false,
|
||||
3,
|
||||
);
|
||||
grad_2.to_data().assert_eq(
|
||||
grad_2.to_data().assert_approx_eq(
|
||||
&TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),
|
||||
false,
|
||||
3,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,8 @@ mod tests {
|
|||
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
|
||||
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
|
||||
|
||||
pub type FloatType = f32;
|
||||
|
||||
// test activation
|
||||
burn_tensor::testgen_gelu!();
|
||||
burn_tensor::testgen_prelu!();
|
||||
|
|
|
@ -523,7 +523,7 @@ mod tests {
|
|||
// Should produce the same tokens.
|
||||
output_1
|
||||
.into_data()
|
||||
.assert_approx_eq(&output_2.into_data(), 3);
|
||||
.assert_approx_eq(&output_2.into_data(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -441,7 +441,7 @@ mod tests {
|
|||
|
||||
output_1
|
||||
.into_data()
|
||||
.assert_approx_eq(&output_2.into_data(), 3);
|
||||
.assert_approx_eq(&output_2.into_data(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@@ -2,38 +2,43 @@
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "CUDA backend for the Burn framework"
documentation = "https://docs.rs/burn-cuda"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu", "cuda"]
license.workspace = true
name = "burn-cuda"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
documentation = "https://docs.rs/burn-cuda"
version.workspace = true

[features]
default = ["fusion", "burn-jit/default", "cubecl/default"]
fusion = ["burn-fusion", "burn-jit/fusion"]
autotune = ["burn-jit/autotune"]
default = ["fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
fusion = ["burn-fusion", "burn-jit/fusion"]
std = ["burn-jit/std", "cubecl/std"]

[dependencies]
cubecl = { workspace = true, features = ["cuda"] }
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = ["cubecl-cuda"] }
burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
    "cubecl-cuda",
] }
cubecl = { workspace = true, features = ["cuda"] }

half = { workspace = true }
bytemuck = { workspace = true }
half = { workspace = true }

log = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true }


[dev-dependencies]
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [
    "export_tests",
] }
paste = { workspace = true }


[package.metadata.docs.rs]
features = ["doc"]
@@ -17,6 +17,7 @@ mod tests {
    use burn_jit::JitBackend;

    pub type TestRuntime = cubecl::cuda::CudaRuntime;
    pub use half::{bf16, f16};

    burn_jit::testgen_all!();
    burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]);
}

@@ -9,7 +9,7 @@ fn speech_command() {

    println!("Item: {:?}", item);
    println!("Item Length: {:?}", item.audio_samples.len());
    println!("Label: {}", item.label.to_string());
    println!("Label: {}", item.label);

    assert_eq!(test.len(), 4890);
    assert_eq!(item.label.to_string(), "Yes");

@@ -5,7 +5,7 @@ use crate::{

use hound::WavReader;
use serde::{Deserialize, Serialize};
use strum::EnumCount;
use strum::EnumCount as _;
use strum_macros::{Display, EnumCount, FromRepr};

type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples, SpeechItemRaw>;

@@ -367,7 +367,7 @@ mod tests {
        let dataset: DataframeDataset<TestData> = DataframeDataset::new(df).unwrap();

        assert_eq!(dataset.len(), 3);
        assert_eq!(dataset.is_empty(), false);
        assert!(!dataset.is_empty());

        let item = dataset.get(1).unwrap();
        assert_eq!(
@@ -439,7 +439,7 @@ mod tests {
        let dataset = DataframeDataset::<PartialTestData>::new(df).unwrap();

        assert_eq!(dataset.len(), 3);
        assert_eq!(dataset.is_empty(), false);
        assert!(!dataset.is_empty());

        let item = dataset.get(1).unwrap();
        assert_eq!(
@@ -8,6 +8,7 @@ use hashbrown::HashMap;
///
/// It also contains all scalar values, which can change even for the same graph. They are sorted
/// in the order in which they appear in the graph.
#[allow(clippy::too_many_arguments)]
#[derive(new)]
pub struct Context<'a, H> {
    /// The tensor mapping where local tensor id points to the updated tensor description.
@@ -20,8 +21,22 @@ pub struct Context<'a, H> {
    pub scalar_f16: &'a Vec<f16>,
    /// BF16 scalars found in the graph in the order they appeared.
    pub scalar_bf16: &'a Vec<bf16>,
    /// Int scalars found in the graph in the order they appeared.
    pub scalar_ints: &'a Vec<i32>,
    /// i64 scalars found in the graph in the order they appeared.
    pub scalar_i64: &'a Vec<i64>,
    /// i32 scalars found in the graph in the order they appeared.
    pub scalar_i32: &'a Vec<i32>,
    /// i16 scalars found in the graph in the order they appeared.
    pub scalar_i16: &'a Vec<i16>,
    /// i8 scalars found in the graph in the order they appeared.
    pub scalar_i8: &'a Vec<i8>,
    /// u64 scalars found in the graph in the order they appeared.
    pub scalar_u64: &'a Vec<u64>,
    /// u32 scalars found in the graph in the order they appeared.
    pub scalar_u32: &'a Vec<u32>,
    /// u16 scalars found in the graph in the order they appeared.
    pub scalar_u16: &'a Vec<u16>,
    /// u8 scalars found in the graph in the order they appeared.
    pub scalar_u8: &'a Vec<u8>,
}

#[derive(Default)]
@@ -34,7 +49,14 @@ pub(crate) struct OperationConverter {
    scalar_f32: Vec<f32>,
    scalar_f16: Vec<f16>,
    scalar_bf16: Vec<bf16>,
    scalar_ints: Vec<i32>,
    scalar_i64: Vec<i64>,
    scalar_i32: Vec<i32>,
    scalar_i16: Vec<i16>,
    scalar_i8: Vec<i8>,
    scalar_u64: Vec<u64>,
    scalar_u32: Vec<u32>,
    scalar_u16: Vec<u16>,
    scalar_u8: Vec<u8>,
}

pub(crate) trait RelativeOps {
@@ -63,7 +85,14 @@ impl OperationConverter {
            scalar_f32: &self.scalar_f32,
            scalar_f16: &self.scalar_f16,
            scalar_bf16: &self.scalar_bf16,
            scalar_ints: &self.scalar_ints,
            scalar_i64: &self.scalar_i64,
            scalar_i32: &self.scalar_i32,
            scalar_i16: &self.scalar_i16,
            scalar_i8: &self.scalar_i8,
            scalar_u64: &self.scalar_u64,
            scalar_u32: &self.scalar_u32,
            scalar_u16: &self.scalar_u16,
            scalar_u8: &self.scalar_u8,
        }
    }

@@ -74,7 +103,14 @@ impl OperationConverter {
        self.scalar_f32.clear();
        self.scalar_f16.clear();
        self.scalar_bf16.clear();
        self.scalar_ints.clear();
        self.scalar_i64.clear();
        self.scalar_i32.clear();
        self.scalar_i16.clear();
        self.scalar_i8.clear();
        self.scalar_u64.clear();
        self.scalar_u32.clear();
        self.scalar_u16.clear();
        self.scalar_u8.clear();
    }

    pub(crate) fn relative_float<E: Element>(&mut self, elem: &E, dtype: &DType) -> E {
@@ -90,8 +126,18 @@ impl OperationConverter {
        0.elem()
    }

    pub(crate) fn relative_int<E: Element>(&mut self, elem: &E) -> E {
        self.scalar_ints.push(elem.elem());
    pub(crate) fn relative_int<E: Element>(&mut self, elem: &E, dtype: &DType) -> E {
        match dtype {
            DType::I64 => self.scalar_i64.push(elem.elem()),
            DType::I32 => self.scalar_i32.push(elem.elem()),
            DType::I16 => self.scalar_i16.push(elem.elem()),
            DType::I8 => self.scalar_i8.push(elem.elem()),
            DType::U64 => self.scalar_u64.push(elem.elem()),
            DType::U32 => self.scalar_u32.push(elem.elem()),
            DType::U16 => self.scalar_u16.push(elem.elem()),
            DType::U8 => self.scalar_u8.push(elem.elem()),
            _ => todo!("Unsupported"),
        }
        // We return 0 so that the id from a scalar operation is the same no matter its scalar
        // value.
        0.elem()
@@ -116,7 +162,7 @@ impl RelativeOps for OperationDescription {
            ),
            OperationDescription::NumericInt(dtype, ops) => OperationDescription::NumericInt(
                *dtype,
                ops.to_relative(converter, |converter, e| converter.relative_int(e)),
                ops.to_relative(converter, |converter, e| converter.relative_int(e, dtype)),
            ),
            OperationDescription::Bool(ops) => {
                OperationDescription::Bool(ops.to_relative(converter))
@@ -2,13 +2,13 @@
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Generic backend that can be compiled just-in-time to any shader language target"
documentation = "https://docs.rs/burn-jit"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu"]
license.workspace = true
name = "burn-jit"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit"
documentation = "https://docs.rs/burn-jit"
version.workspace = true

[features]
@@ -22,6 +22,7 @@ export_tests = [
    "burn-tensor/export_tests",
    "burn-ndarray",
    "fusion",
    "paste",
]
fusion = ["burn-fusion"]
std = ["cubecl/std"]
@@ -31,7 +32,8 @@ template = []
burn-common = { path = "../burn-common", version = "0.16.0" }
burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
    "cubecl", "repr",
    "cubecl",
    "repr",
] }
cubecl = { workspace = true, features = ["linalg"] }

@@ -56,6 +58,7 @@ hashbrown = { workspace = true }
# When exporting tests
burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, optional = true }
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true }
paste = { workspace = true, optional = true }
serial_test = { workspace = true, optional = true }

[package.metadata.docs.rs]
@@ -1,4 +1,5 @@
use cubecl::{
    flex32,
    prelude::{Float, Int, Numeric},
    CubeElement,
};
@@ -12,17 +13,26 @@ pub trait FloatElement: JitElement + Float {}
/// The int element type for the jit backend.
pub trait IntElement: JitElement + Int {}

impl JitElement for u64 {}
impl JitElement for u32 {}

impl JitElement for u16 {}
impl JitElement for u8 {}
impl JitElement for i64 {}
impl JitElement for i32 {}

impl JitElement for i16 {}
impl JitElement for i8 {}
impl JitElement for f64 {}
impl JitElement for f32 {}

impl JitElement for flex32 {}
impl JitElement for half::f16 {}

impl JitElement for half::bf16 {}

impl FloatElement for f64 {}
impl FloatElement for f32 {}
impl FloatElement for flex32 {}
impl FloatElement for half::bf16 {}
impl FloatElement for half::f16 {}
impl IntElement for i64 {}
impl IntElement for i32 {}
impl IntElement for i16 {}
impl IntElement for i8 {}
@@ -68,15 +68,29 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseOptimization<R> {
            Arg::Input(index, precision, _) => match precision {
                ElemwisePrecision::F32 => inputs.t_f32.values.get(index as usize),
                ElemwisePrecision::F16 => inputs.t_f16.values.get(index as usize),
                ElemwisePrecision::BF16 => inputs.t_bf16.values.get(index as usize),
                ElemwisePrecision::U64 => inputs.t_u64.values.get(index as usize),
                ElemwisePrecision::U32 => inputs.t_u32.values.get(index as usize),
                ElemwisePrecision::U16 => inputs.t_u16.values.get(index as usize),
                ElemwisePrecision::U8 => inputs.t_u8.values.get(index as usize),
                ElemwisePrecision::I64 => inputs.t_i64.values.get(index as usize),
                ElemwisePrecision::I32 => inputs.t_i32.values.get(index as usize),
                ElemwisePrecision::I16 => inputs.t_i16.values.get(index as usize),
                ElemwisePrecision::I8 => inputs.t_i8.values.get(index as usize),
                _ => panic!("Invalid value"),
            },
            Arg::Output(index, precision, _) => match precision {
                ElemwisePrecision::F32 => outputs.t_f32.values.get(index as usize),
                ElemwisePrecision::F16 => outputs.t_f16.values.get(index as usize),
                ElemwisePrecision::BF16 => outputs.t_bf16.values.get(index as usize),
                ElemwisePrecision::U64 => outputs.t_u64.values.get(index as usize),
                ElemwisePrecision::U32 => outputs.t_u32.values.get(index as usize),
                ElemwisePrecision::U16 => outputs.t_u16.values.get(index as usize),
                ElemwisePrecision::U8 => outputs.t_u8.values.get(index as usize),
                ElemwisePrecision::I64 => outputs.t_i64.values.get(index as usize),
                ElemwisePrecision::I32 => outputs.t_i32.values.get(index as usize),
                ElemwisePrecision::I16 => outputs.t_i16.values.get(index as usize),
                ElemwisePrecision::I8 => outputs.t_i8.values.get(index as usize),
                _ => panic!("Invalid value"),
            },
            _ => panic!("Invalid value"),
@@ -142,11 +156,11 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseOptimization<R> {
        let mut output = u8::MAX;

        for (handle, tensor) in handles_inputs.zip(inputs) {
            output = u8::min(vectorization_input(handle, tensor), output);
            output = Ord::min(vectorization_input(handle, tensor), output);
        }

        for tensor in outputs {
            output = u8::min(vectorization_output(tensor), output);
            output = Ord::min(vectorization_output(tensor), output);
        }

        output
@@ -168,15 +182,29 @@ fn elemwise_fuse(
        Arg::Input(index, precision, _) => match comptime![precision] {
            ElemwisePrecision::F32 => inputs.t_f32.index(index).len(),
            ElemwisePrecision::F16 => inputs.t_f16.index(index).len(),
            ElemwisePrecision::BF16 => inputs.t_bf16.index(index).len(),
            ElemwisePrecision::U64 => inputs.t_u64.index(index).len(),
            ElemwisePrecision::U32 => inputs.t_u32.index(index).len(),
            ElemwisePrecision::U16 => inputs.t_u16.index(index).len(),
            ElemwisePrecision::U8 => inputs.t_u8.index(index).len(),
            ElemwisePrecision::I64 => inputs.t_i64.index(index).len(),
            ElemwisePrecision::I32 => inputs.t_i32.index(index).len(),
            ElemwisePrecision::I16 => inputs.t_i16.index(index).len(),
            ElemwisePrecision::I8 => inputs.t_i8.index(index).len(),
            _ => comptime![panic!("Unsupported precision {precision:?}")],
        },
        Arg::Output(index, precision, _) => match comptime![precision] {
            ElemwisePrecision::F32 => outputs.t_f32.index(index).len(),
            ElemwisePrecision::F16 => outputs.t_f16.index(index).len(),
            ElemwisePrecision::BF16 => outputs.t_bf16.index(index).len(),
            ElemwisePrecision::U64 => outputs.t_u64.index(index).len(),
            ElemwisePrecision::U32 => outputs.t_u32.index(index).len(),
            ElemwisePrecision::U16 => outputs.t_u16.index(index).len(),
            ElemwisePrecision::U8 => outputs.t_u8.index(index).len(),
            ElemwisePrecision::I64 => outputs.t_i64.index(index).len(),
            ElemwisePrecision::I32 => outputs.t_i32.index(index).len(),
            ElemwisePrecision::I16 => outputs.t_i16.index(index).len(),
            ElemwisePrecision::I8 => outputs.t_i8.index(index).len(),
            _ => comptime![panic!("Unsupported precision {precision:?}")],
        },
        _ => comptime![panic!("Invalid ref layout.")],
@ -1,337 +0,0 @@
|
|||
use cubecl::calculate_num_elems_dyn_rank;
|
||||
use cubecl::prelude::*;
|
||||
|
||||
use crate::fusion::strides_dyn_rank;
|
||||
use crate::fusion::JitFusionHandle;
|
||||
use crate::kernel::Kernel;
|
||||
use crate::JitRuntime;
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_tensor::repr::TensorDescription;
|
||||
use burn_tensor::repr::TensorStatus;
|
||||
use cubecl::client::ComputeClient;
|
||||
use cubecl::server::Binding;
|
||||
use cubecl::tune::AutotuneOperation;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::tracing::ExecutionInfo;
|
||||
|
||||
#[derive(new)]
|
||||
pub struct FusionKernel<R: JitRuntime> {
|
||||
id: u64, // Same ID for all different settings.
|
||||
info: Arc<KernelExpansion>,
|
||||
settings: KernelSettings,
|
||||
runtime_info: Vec<OutputRuntimeInfo>,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
_runtime: PhantomData<R>,
|
||||
}
|
||||
|
||||
pub trait FusionKernelFactory<R: JitRuntime> {
|
||||
/// Create a new kernel.
|
||||
fn create(
|
||||
&self,
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
stateful: bool, // Should be set to false when running autotune.
|
||||
) -> FusionKernel<R>;
|
||||
}
|
||||
|
||||
/// An instantiation of a [kernel](Kernel) that can be executed.
|
||||
#[derive(new)]
|
||||
pub struct ExecutableKernel<R: JitRuntime> {
|
||||
kernel: Box<dyn CubeTask<R::Compiler>>,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
bindings: Vec<Binding<R::Server>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
}
|
||||
|
||||
/// An instantiation of a [kernel](Kernel) that can be autotuned.
|
||||
///
|
||||
/// The main difference with an [executable kernel](ExecutableKernel) is that this kernel can be
|
||||
/// cloned and executed multiple times to properly collect benchmarks.
|
||||
///
|
||||
/// The clone function used is defined in the trait [AutotuneOperation] instead of [Clone].
|
||||
#[derive(new)]
|
||||
pub struct AutotunableKernel<R: JitRuntime> {
|
||||
kernel: Arc<dyn CubeTask<R::Compiler>>,
|
||||
count: CubeCount<R::Server>,
|
||||
bindings: Vec<Binding<R::Server>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
}
|
||||
|
||||
// Information related to the output of this kernel.
|
||||
#[derive(Debug)]
|
||||
pub enum OutputRuntimeInfo {
|
||||
Inplace { input_index: usize },
|
||||
Array { size: usize },
|
||||
}
|
||||
|
||||
impl<R: JitRuntime> ExecutableKernel<R> {
|
||||
/// Execute the kernel.
|
||||
pub fn execute(self) {
|
||||
unsafe {
|
||||
self.client
|
||||
.execute_unchecked(self.kernel, self.cube_count, self.bindings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> {
|
||||
fn execute(self: Box<Self>) {
|
||||
self.client
|
||||
.execute(Box::new(self.kernel), self.count, self.bindings)
|
||||
}
|
||||
|
||||
fn clone(&self) -> Box<dyn AutotuneOperation> {
|
||||
Box::new(Self {
|
||||
kernel: self.kernel.clone(),
|
||||
count: self.count.clone(),
|
||||
bindings: self.bindings.clone(),
|
||||
client: self.client.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
|
||||
fn from(value: ExecutableKernel<R>) -> Self {
|
||||
Self {
|
||||
kernel: Arc::new(value.kernel),
|
||||
count: value.cube_count.clone(),
|
||||
bindings: value.bindings,
|
||||
client: value.client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime> FusionKernel<R> {
|
||||
pub fn create<K>(
|
||||
factory: &K,
|
||||
running_info: &ExecutionInfo<'_>,
|
||||
context: &mut Context<'_, JitFusionHandle<R>>,
|
||||
device: R::Device,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
stateful: bool,
|
||||
) -> ExecutableKernel<R>
|
||||
where
|
||||
K: FusionKernelFactory<R>,
|
||||
{
|
||||
let (handles_input, inputs_description_updated, outputs_description_updated) =
|
||||
process_inputs_outputs(
|
||||
&running_info.inputs,
|
||||
&running_info.outputs,
|
||||
context,
|
||||
stateful,
|
||||
);
|
||||
|
||||
let fusion_kernel = factory.create(
|
||||
&handles_input,
|
||||
&inputs_description_updated,
|
||||
&outputs_description_updated,
|
||||
stateful,
|
||||
);
|
||||
|
||||
let rank_input = running_info
|
||||
.inputs
|
||||
.first()
|
||||
.map(|desc| desc.shape.len())
|
||||
.unwrap_or(1);
|
||||
let rank_output = running_info
|
||||
.outputs
|
||||
.first()
|
||||
.map(|desc| desc.shape.len())
|
||||
.unwrap_or(1);
|
||||
let rank = usize::max(rank_input, rank_output);
|
||||
|
||||
let num_tensors = running_info.inputs.len() + running_info.outputs.len();
|
||||
// The buffer starts with the rank, then each tensor shape and stride.
|
||||
        let info_size = (num_tensors * rank * 2) + 1;

        let mut num_handles = num_tensors + 1;
        if running_info.scalars.num_f32 > 0 {
            num_handles += 1;
        }
        if running_info.scalars.num_f16 > 0 {
            num_handles += 1;
        }
        if running_info.scalars.num_bf16 > 0 {
            num_handles += 1;
        }
        if running_info.scalars.num_int > 0 {
            num_handles += 1;
        }

        let mut info = Vec::with_capacity(info_size);
        let mut bindings = Vec::with_capacity(num_handles);
        let mut output_register = Vec::with_capacity(outputs_description_updated.len());

        // We register the info and handles for the inputs.
        for (handle, tensor) in handles_input.iter().zip(inputs_description_updated.iter()) {
            register_info_tensor(&mut info, tensor, handle);
            bindings.push(handle.handle.clone().binding());
        }

        // We register the info and handles for the outputs.
        for (tensor, output_info) in outputs_description_updated
            .iter()
            .zip(fusion_kernel.runtime_info.iter())
        {
            match output_info {
                // Use the input inplace for this output.
                OutputRuntimeInfo::Inplace { input_index } => {
                    let input = handles_input.get(*input_index).unwrap();

                    let handle_fusion = JitFusionHandle {
                        client: client.clone(),
                        device: device.clone(),
                        strides: strides_dyn_rank(&tensor.shape),
                        handle: input.handle.clone(),
                    };
                    output_register.push((tensor.id, handle_fusion));
                }
                // Create a new buffer for this output.
                OutputRuntimeInfo::Array { size } => {
                    let handle_fusion = JitFusionHandle {
                        client: client.clone(),
                        device: device.clone(),
                        strides: strides_dyn_rank(&tensor.shape),
                        handle: client.empty(*size),
                    };

                    register_info_tensor(&mut info, tensor, &handle_fusion);
                    bindings.push(handle_fusion.handle.clone().binding());
                    output_register.push((tensor.id, handle_fusion));
                }
            };
        }

        // [2, I0stride0, I0stride1, I0shape0, I0shape1, I1..., O0..., I0len, I1len, O0len]
        if R::require_array_lengths() {
            for input in inputs_description_updated.iter() {
                let len = calculate_num_elems_dyn_rank(&input.shape);
                info.push(len as u32);
            }

            for output in outputs_description_updated.iter() {
                let len = calculate_num_elems_dyn_rank(&output.shape);
                info.push(len as u32);
            }
        }

        // Create the info buffer.
        bindings.push(client.create(bytemuck::cast_slice(&info)).binding());

        // Finally we finish with the named bindings.
        if running_info.scalars.num_f32 > 0 {
            let bytes = bytemuck::cast_slice(&context.scalar_f32[0..running_info.scalars.num_f32]);
            bindings.push(client.create(bytes).binding());
        }

        if running_info.scalars.num_f16 > 0 {
            let bytes = bytemuck::cast_slice(&context.scalar_f16[0..running_info.scalars.num_f16]);
            bindings.push(client.create(bytes).binding());
        }

        if running_info.scalars.num_bf16 > 0 {
            let bytes =
                bytemuck::cast_slice(&context.scalar_bf16[0..running_info.scalars.num_bf16]);
            bindings.push(client.create(bytes).binding());
        }

        if running_info.scalars.num_int > 0 {
            bindings.push(
                client
                    .create(bytemuck::cast_slice(
                        &context.scalar_ints[0..running_info.scalars.num_int],
                    ))
                    .binding(),
            );
        }

        // We have to register the output handles to the context.
        for (id, handle) in output_register {
            context.handles.register_handle(id, handle);
        }

        let cube_count = fusion_kernel.cube_count.clone();
        ExecutableKernel::new(
            Box::new(KernelTask::<R::Compiler, _>::new(fusion_kernel)),
            cube_count,
            bindings,
            client,
        )
    }
}

impl<R: JitRuntime> Kernel for FusionKernel<R> {
    fn define(&self) -> KernelDefinition {
        log::info!("Compiling ... {:?}", self.id());
        KernelIntegrator::new(self.info.as_ref().clone()).integrate(self.settings.clone())
    }

    fn id(&self) -> cubecl::KernelId {
        cubecl::KernelId::new::<Self>().info((self.settings.clone(), self.id))
    }
}

fn register_info_tensor<R: JitRuntime>(
    info: &mut Vec<u32>,
    tensor: &TensorDescription,
    handle: &JitFusionHandle<R>,
) {
    if info.is_empty() {
        info.push(handle.strides.len() as u32);
    }

    for s in handle.strides.iter() {
        info.push(*s as u32);
    }
    for s in tensor.shape.iter() {
        info.push(*s as u32);
    }
}

fn process_inputs_outputs<'a, R>(
    inputs: &[&TensorDescription],
    outputs: &[&TensorDescription],
    context: &'a mut Context<'_, JitFusionHandle<R>>,
    stateful: bool,
) -> (
    Vec<JitFusionHandle<R>>,
    Vec<&'a TensorDescription>,
    Vec<&'a TensorDescription>,
)
where
    R: JitRuntime,
{
    let mut inputs_description_updated = Vec::with_capacity(inputs.len());
    let mut outputs_description_updated = Vec::with_capacity(outputs.len());
    let mut handles_input = Vec::new();

    for tensor in inputs.iter() {
        let status = if stateful {
            &tensor.status // Important to take the status of the relative graph and not
                           // the global graph, since the status of the global graph
                           // might be of a later operation on the same tensor id.
        } else {
            &TensorStatus::ReadOnly
        };

        let tensor = context.tensors.get(&tensor.id).unwrap();
        let handle = context.handles.get_handle(&tensor.id, status);

        handles_input.push(handle);
        inputs_description_updated.push(tensor);
    }

    for tensor in outputs.iter() {
        let tensor = context.tensors.get(&tensor.id).unwrap();
        outputs_description_updated.push(tensor);
    }

    (
        handles_input,
        inputs_description_updated,
        outputs_description_updated,
    )
}
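Illustration (not part of the diff): the metadata buffer built above follows the layout in the `// [2, I0stride0, ...]` comment. A minimal host-side sketch for one 2x4 input and one 2x4 output, assuming the runtime requires array lengths:

// Hypothetical illustration only: rank first, then strides and shape per tensor,
// then the element counts appended when `R::require_array_lengths()` is true.
fn example_info_layout() -> Vec<u32> {
    let (input_strides, input_shape) = ([4u32, 1], [2u32, 4]);
    let (output_strides, output_shape) = ([4u32, 1], [2u32, 4]);

    let mut info = Vec::new();
    info.push(2); // rank, pushed once by the first `register_info_tensor` call
    info.extend(input_strides); // I0 strides
    info.extend(input_shape); // I0 shape
    info.extend(output_strides); // O0 strides
    info.extend(output_shape); // O0 shape
    info.push(8); // I0 length (2 * 4 elements)
    info.push(8); // O0 length
    info
}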
@ -31,6 +31,24 @@ pub fn read<C: CubePrimitive>(
|
|||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::BF16 => {
|
||||
let tensor = inputs.t_bf16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
let tensor = inputs.t_u64.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
let tensor = inputs.t_u32.index(pos);
|
||||
let offset = match layout {
|
||||
|
@ -40,6 +58,33 @@ pub fn read<C: CubePrimitive>(
|
|||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
let tensor = inputs.t_u16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
let tensor = inputs.t_u8.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
let tensor = inputs.t_i64.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
let tensor = inputs.t_i32.index(pos);
|
||||
let offset = match layout {
|
||||
|
@ -49,6 +94,24 @@ pub fn read<C: CubePrimitive>(
|
|||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
let tensor = inputs.t_i16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
let tensor = inputs.t_i8.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {precision:?}")],
|
||||
},
|
||||
Arg::Output(pos, precision, layout) => match comptime![precision] {
|
||||
|
@ -70,6 +133,24 @@ pub fn read<C: CubePrimitive>(
|
|||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::BF16 => {
|
||||
let tensor = outputs.t_bf16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
let tensor = outputs.t_u64.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
let tensor = outputs.t_u32.index(pos);
|
||||
let offset = match layout {
|
||||
|
@ -79,6 +160,33 @@ pub fn read<C: CubePrimitive>(
|
|||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
let tensor = outputs.t_u16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
let tensor = outputs.t_u8.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
let tensor = outputs.t_i64.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
let tensor = outputs.t_i32.index(pos);
|
||||
let offset = match layout {
|
||||
|
@ -88,22 +196,52 @@ pub fn read<C: CubePrimitive>(
|
|||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
let tensor = outputs.t_i16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
let tensor = outputs.t_i8.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
Line::cast_from(tensor[offset])
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {precision:?}")],
|
||||
},
|
||||
        Arg::Local(pos, precision) => match comptime![precision] {
            ElemwisePrecision::F32 => Line::cast_from(locals.l_f32.find(pos)),
            ElemwisePrecision::F16 => Line::cast_from(locals.l_f16.find(pos)),
            ElemwisePrecision::BF16 => Line::cast_from(locals.l_bf16.find(pos)),
            ElemwisePrecision::U64 => Line::cast_from(locals.l_u64.find(pos)),
            ElemwisePrecision::U32 => Line::cast_from(locals.l_u32.find(pos)),
            ElemwisePrecision::U16 => Line::cast_from(locals.l_u16.find(pos)),
            ElemwisePrecision::U8 => Line::cast_from(locals.l_u8.find(pos)),
            ElemwisePrecision::I64 => Line::cast_from(locals.l_i64.find(pos)),
            ElemwisePrecision::I32 => Line::cast_from(locals.l_i32.find(pos)),
            ElemwisePrecision::I16 => Line::cast_from(locals.l_i16.find(pos)),
            ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)),
            ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)),
            _ => comptime![panic!("Unsupported precision {precision:?}")],
        },
        Arg::Scalar(pos, precision) => match comptime![precision] {
            ElemwisePrecision::F32 => Line::cast_from(*inputs.s_f32.index(pos)),
            ElemwisePrecision::F16 => Line::cast_from(*inputs.s_f16.index(pos)),
            ElemwisePrecision::BF16 => Line::cast_from(*inputs.s_bf16.index(pos)),
            ElemwisePrecision::U64 => Line::cast_from(*inputs.s_u64.index(pos)),
            ElemwisePrecision::U32 => Line::cast_from(*inputs.s_u32.index(pos)),
            ElemwisePrecision::U16 => Line::cast_from(*inputs.s_u16.index(pos)),
            ElemwisePrecision::U8 => Line::cast_from(*inputs.s_u8.index(pos)),
            ElemwisePrecision::I64 => Line::cast_from(*inputs.s_i64.index(pos)),
            ElemwisePrecision::I32 => Line::cast_from(*inputs.s_i32.index(pos)),
            ElemwisePrecision::BF16 => comptime![panic!("Can't write into inputs or scalars")],
            ElemwisePrecision::I16 => Line::cast_from(*inputs.s_i16.index(pos)),
            ElemwisePrecision::I8 => Line::cast_from(*inputs.s_i8.index(pos)),
            _ => comptime![panic!("Unsupported precision {precision:?}")],
        },
        Arg::Literal(val, _precision) => Line::cast_from(val.runtime()),
|
@ -143,6 +281,26 @@ pub fn write<C: CubePrimitive>(
|
|||
let tensor = outputs.t_f16.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::BF16 => {
|
||||
let tensor = outputs.t_bf16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_bf16.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
let tensor = outputs.t_u64.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_u64.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
let tensor = outputs.t_u32.index(pos);
|
||||
let offset = match layout {
|
||||
|
@ -153,6 +311,36 @@ pub fn write<C: CubePrimitive>(
|
|||
let tensor = outputs.t_u32.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
let tensor = outputs.t_u16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_u16.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
let tensor = outputs.t_u8.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_u8.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
let tensor = outputs.t_i64.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_i64.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
let tensor = outputs.t_i32.index(pos);
|
||||
let offset = match layout {
|
||||
|
@ -163,15 +351,41 @@ pub fn write<C: CubePrimitive>(
|
|||
let tensor = outputs.t_i32.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
let tensor = outputs.t_i16.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_i16.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
let tensor = outputs.t_i8.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config),
|
||||
};
|
||||
let tensor = outputs.t_i8.index_mut(pos);
|
||||
tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {precision:?}")],
|
||||
},
|
||||
Arg::Local(pos, precision) => match comptime![precision] {
|
||||
ElemwisePrecision::F32 => locals.l_f32.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::F16 => locals.l_f16.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::U64 => locals.l_u64.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::U32 => locals.l_u32.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::U16 => locals.l_u16.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::U8 => locals.l_u8.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::I64 => locals.l_i64.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::I32 => locals.l_i32.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::I16 => locals.l_i16.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::I8 => locals.l_i8.insert(pos, Line::cast_from(value)),
|
||||
ElemwisePrecision::Bool => locals.l_bool.insert(pos, Line::cast_from(value)),
|
||||
_ => comptime![panic!("Unsupported precision {precision:?}")],
|
||||
},
|
||||
_ => comptime![panic!("Can't write into inputs and scalars")],
|
||||
}
|
||||
|
@ -195,14 +409,42 @@ fn get_offset<C: CubePrimitive>(
|
|||
let layout = inputs.t_f16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::BF16 => {
|
||||
let layout = inputs.t_bf16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
let layout = inputs.t_u64.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
let layout = inputs.t_u32.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
let layout = inputs.t_u16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
let layout = inputs.t_u8.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
let layout = inputs.t_i64.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
let layout = inputs.t_i32.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
let layout = inputs.t_i16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
let layout = inputs.t_i8.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {precision:?}")],
|
||||
},
|
||||
Arg::Output(index, precision, _) => match comptime![precision] {
|
||||
|
@ -214,14 +456,42 @@ fn get_offset<C: CubePrimitive>(
|
|||
let layout = outputs.t_f16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::BF16 => {
|
||||
let layout = outputs.t_bf16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
let layout = outputs.t_u64.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
let layout = outputs.t_u32.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
let layout = outputs.t_u16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
let layout = outputs.t_u8.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
let layout = outputs.t_i64.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
let layout = outputs.t_i32.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
let layout = outputs.t_i16.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
let layout = outputs.t_i8.index(index);
|
||||
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {precision:?}")],
|
||||
},
|
||||
_ => comptime![panic!("Invalid ref layout.")],
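Illustration (not the cubecl kernel code): every arm of `get_offset` above delegates to `index_offset_with_layout`, which remaps a linear position in the reference layout into an element offset under the target tensor's strides. A rough, self-contained analogue of that computation, assuming equal-rank stride and shape slices:

// Hedged host-side analogue of the offset remapping used for LayoutInfo::Unknown.
fn index_offset_with_layout_sketch(
    target_strides: &[u32],
    layout_strides: &[u32],
    layout_shape: &[u32],
    pos: u32,
) -> u32 {
    let mut offset = 0;
    for dim in 0..layout_shape.len() {
        // Recover the coordinate along `dim` from the reference layout...
        let coordinate = (pos / layout_strides[dim]) % layout_shape[dim];
        // ...and re-project it onto the target tensor's strides.
        offset += coordinate * target_strides[dim];
    }
    offset
}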
@ -79,13 +79,25 @@ pub struct GlobalArgs {
    pub t_f32: Sequence<Tensor<Line<f32>>>,
    pub t_f16: Sequence<Tensor<Line<f16>>>,
    pub t_bf16: Sequence<Tensor<Line<bf16>>>,
    pub t_i64: Sequence<Tensor<Line<i64>>>,
    pub t_i32: Sequence<Tensor<Line<i32>>>,
    pub t_i16: Sequence<Tensor<Line<i16>>>,
    pub t_i8: Sequence<Tensor<Line<i8>>>,
    pub t_u64: Sequence<Tensor<Line<u64>>>,
    pub t_u32: Sequence<Tensor<Line<u32>>>,
    pub t_u16: Sequence<Tensor<Line<u16>>>,
    pub t_u8: Sequence<Tensor<Line<u8>>>,
    pub s_f32: Sequence<f32>,
    pub s_f16: Sequence<f16>,
    pub s_bf16: Sequence<bf16>,
    pub s_i64: Sequence<i64>,
    pub s_i32: Sequence<i32>,
    pub s_i16: Sequence<i16>,
    pub s_i8: Sequence<i8>,
    pub s_u64: Sequence<u64>,
    pub s_u32: Sequence<u32>,
    pub s_u16: Sequence<u16>,
    pub s_u8: Sequence<u8>,
}

#[derive(CubeType, Clone)]

@ -95,8 +107,14 @@ pub struct LocalArgs {
    pub l_f32: Registry<u32, Line<f32>>,
    pub l_f16: Registry<u32, Line<f16>>,
    pub l_bf16: Registry<u32, Line<bf16>>,
    pub l_i64: Registry<u32, Line<i64>>,
    pub l_i32: Registry<u32, Line<i32>>,
    pub l_i16: Registry<u32, Line<i16>>,
    pub l_i8: Registry<u32, Line<i8>>,
    pub l_u64: Registry<u32, Line<u64>>,
    pub l_u32: Registry<u32, Line<u32>>,
    pub l_u16: Registry<u32, Line<u16>>,
    pub l_u8: Registry<u32, Line<u8>>,
    pub l_bool: Registry<u32, Line<bool>>,
}
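Illustration (not part of the diff): `LocalArgs` gives each precision its own registry of virtual registers. Conceptually, a `Registry<u32, Line<T>>` behaves like a map from a local index to the last value written for it; `write` fills it through `insert` and `read` fetches it through `find`, as seen in the element-wise kernel above. A plain-`HashMap` analogue, not the cubecl `Registry` type:

use std::collections::HashMap;

// Hypothetical stand-in for one per-precision local registry.
struct LocalRegistrySketch<T> {
    values: HashMap<u32, T>,
}

impl<T: Copy> LocalRegistrySketch<T> {
    // What `write` does for Arg::Local(pos, precision).
    fn insert(&mut self, pos: u32, value: T) {
        self.values.insert(pos, value);
    }

    // What `read` does for the same Arg::Local; the value must have been written first.
    fn find(&self, pos: u32) -> T {
        self.values[&pos]
    }
}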
@ -123,9 +141,13 @@ pub enum ElemwisePrecision {
    F32,
    F16,
    BF16,
    I64,
    I32,
    I16,
    I8,
    U64,
    U32,
    U16,
    U8,
    Bool,
}

@ -139,8 +161,18 @@ impl From<Elem> for ElemwisePrecision {
                cubecl::ir::FloatKind::F32 => Self::F32,
                _ => panic!("Unsupported precision for fusion: {value}"),
            },
            Elem::Int(cubecl::ir::IntKind::I32) => Self::I32,
            Elem::UInt => Self::U32,
            Elem::Int(kind) => match kind {
                cubecl::ir::IntKind::I64 => Self::I64,
                cubecl::ir::IntKind::I32 => Self::I32,
                cubecl::ir::IntKind::I16 => Self::I16,
                cubecl::ir::IntKind::I8 => Self::I8,
            },
            Elem::UInt(kind) => match kind {
                cubecl::ir::UIntKind::U64 => Self::U64,
                cubecl::ir::UIntKind::U32 => Self::U32,
                cubecl::ir::UIntKind::U16 => Self::U16,
                cubecl::ir::UIntKind::U8 => Self::U8,
            },
            Elem::Bool => Self::Bool,
            _ => panic!("Unsupported precision for fusion: {value}"),
        }

@ -153,9 +185,13 @@ impl From<DType> for ElemwisePrecision {
            DType::F32 => Self::F32,
            DType::F16 => Self::F16,
            DType::BF16 => Self::BF16,
            DType::I64 => Self::I64,
            DType::I32 => Self::I32,
            DType::I16 => Self::I16,
            DType::I8 => Self::I8,
            DType::U64 => Self::U64,
            DType::U32 => Self::U32,
            DType::U16 => Self::U16,
            DType::U8 => Self::U8,
            DType::Bool => Self::Bool,
            _ => panic!("Unsupported"),
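Illustration (relies only on the `From` impls shown in the hunks above): with the widened enum, each integer and unsigned width resolves to its own precision instead of being funneled through `I32`/`U32`:

// Illustrative only; assumes `DType` and `ElemwisePrecision` are in scope.
fn example_precision_mapping() {
    assert!(matches!(ElemwisePrecision::from(DType::I16), ElemwisePrecision::I16));
    assert!(matches!(ElemwisePrecision::from(DType::U8), ElemwisePrecision::U8));
    assert!(matches!(ElemwisePrecision::from(DType::BF16), ElemwisePrecision::BF16));
}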
@ -19,8 +19,14 @@ pub fn fuse_on_write<E: CubePrimitive>(
        l_f32: Registry::<u32, Line<f32>>::new(),
        l_f16: Registry::<u32, Line<f16>>::new(),
        l_bf16: Registry::<u32, Line<bf16>>::new(),
        l_i64: Registry::<u32, Line<i64>>::new(),
        l_i32: Registry::<u32, Line<i32>>::new(),
        l_i16: Registry::<u32, Line<i16>>::new(),
        l_i8: Registry::<u32, Line<i8>>::new(),
        l_u64: Registry::<u32, Line<u64>>::new(),
        l_u32: Registry::<u32, Line<u32>>::new(),
        l_u16: Registry::<u32, Line<u16>>::new(),
        l_u8: Registry::<u32, Line<u8>>::new(),
        l_bool: Registry::<u32, Line<bool>>::new(),
    };
@ -48,12 +54,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
add::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
add::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
add::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
add::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
add::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
add::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
add::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
add::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
add::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Div(op) => match op.out.precision() {
|
||||
|
@ -66,12 +90,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
div::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
div::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
div::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
div::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
div::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
div::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
div::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
div::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
div::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Sub(op) => match op.out.precision() {
|
||||
|
@ -84,12 +126,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
sub::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
sub::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
sub::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
sub::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
sub::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
sub::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
sub::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
sub::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
sub::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Mul(op) => match op.out.precision() {
|
||||
|
@ -102,12 +162,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
mul::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
mul::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
mul::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
mul::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
mul::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
mul::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
mul::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
mul::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
mul::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Powf(op) => match op.out.precision() {
|
||||
|
@ -144,12 +222,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
abs::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
assign::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
assign::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
assign::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
assign::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
abs::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
abs::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
abs::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
abs::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Log(op) => match op.out.precision() {
|
||||
|
@ -198,16 +294,33 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
assign::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
assign::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
assign::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
assign::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
assign::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
assign::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
assign::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
assign::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
assign::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::Bool => {
|
||||
assign::<bool>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Exp(op) => match op.out.precision() {
|
||||
ElemwisePrecision::F32 => {
|
||||
|
@ -267,12 +380,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
equal::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
equal::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
equal::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
equal::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
equal::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
equal::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
equal::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
equal::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Greater(op) => match op.lhs.precision() {
|
||||
|
@ -285,12 +416,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
greater::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
greater::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
greater::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
greater::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
greater::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
greater::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
greater::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
greater::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
greater::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::GreaterEqual(op) => match op.lhs.precision() {
|
||||
|
@ -303,12 +452,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
greater_equal::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
greater_equal::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
greater_equal::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
greater_equal::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
greater_equal::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
greater_equal::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
greater_equal::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
greater_equal::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
greater_equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::Lower(op) => match op.lhs.precision() {
|
||||
|
@ -321,12 +488,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
lower::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
lower::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
lower::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
lower::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
lower::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
lower::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
lower::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
lower::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
lower::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::LowerEqual(op) => match op.lhs.precision() {
|
||||
|
@ -339,12 +524,30 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
ElemwisePrecision::BF16 => {
|
||||
lower_equal::<bf16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I64 => {
|
||||
lower_equal::<i64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
lower_equal::<i32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
lower_equal::<i16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::I8 => {
|
||||
lower_equal::<i8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U64 => {
|
||||
lower_equal::<u64>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
lower_equal::<u32>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
lower_equal::<u16>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
ElemwisePrecision::U8 => {
|
||||
lower_equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
|
||||
}
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
ElemwiseOp::ConditionalAssign {
|
||||
|
@ -386,6 +589,17 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::I64 => conditional_assign::<i64>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
write_pos,
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::I32 => conditional_assign::<i32>(
|
||||
inputs,
|
||||
outputs,
|
||||
|
@ -397,6 +611,39 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::I16 => conditional_assign::<i16>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
write_pos,
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::I8 => conditional_assign::<i8>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
write_pos,
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::U64 => conditional_assign::<u64>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
write_pos,
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::U32 => conditional_assign::<u32>(
|
||||
inputs,
|
||||
outputs,
|
||||
|
@ -408,6 +655,28 @@ pub fn fuse_on_write<E: CubePrimitive>(
|
|||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::U16 => conditional_assign::<u16>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
write_pos,
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
config,
|
||||
),
|
||||
ElemwisePrecision::U8 => conditional_assign::<u8>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
write_pos,
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
config,
|
||||
),
|
||||
_ => comptime![panic!("Unsupported precision {op:?}")],
|
||||
},
|
||||
}
|
||||
|
|
|
@ -333,6 +333,18 @@ impl FuseOnWriteTrace {
|
|||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
);
|
||||
|
||||
for hi in handle_inputs.iter() {
|
||||
|
@ -341,8 +353,14 @@ impl FuseOnWriteTrace {
|
|||
ElemwisePrecision::F32 => inputs.t_f32.push(arg),
|
||||
ElemwisePrecision::F16 => inputs.t_f16.push(arg),
|
||||
ElemwisePrecision::BF16 => inputs.t_bf16.push(arg),
|
||||
ElemwisePrecision::I64 => inputs.t_i64.push(arg),
|
||||
ElemwisePrecision::I32 => inputs.t_i32.push(arg),
|
||||
ElemwisePrecision::I16 => inputs.t_i16.push(arg),
|
||||
ElemwisePrecision::I8 => inputs.t_i8.push(arg),
|
||||
ElemwisePrecision::U64 => inputs.t_u64.push(arg),
|
||||
ElemwisePrecision::U32 => inputs.t_u32.push(arg),
|
||||
ElemwisePrecision::U16 => inputs.t_u16.push(arg),
|
||||
ElemwisePrecision::U8 => inputs.t_u8.push(arg),
|
||||
_ => panic!("Unsupported input precision {:?}", hi.precision),
|
||||
};
|
||||
}
|
||||
|
@ -356,10 +374,30 @@ impl FuseOnWriteTrace {
|
|||
ElemwisePrecision::F16 => {
|
||||
inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i]))
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
inputs.s_i32.push(ScalarArg::new(context.scalar_ints[i]))
|
||||
ElemwisePrecision::BF16 => {
|
||||
inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i]))
|
||||
}
|
||||
_ => todo!(),
|
||||
ElemwisePrecision::I64 => {
|
||||
inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i]))
|
||||
}
|
||||
ElemwisePrecision::I32 => {
|
||||
inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i]))
|
||||
}
|
||||
ElemwisePrecision::I16 => {
|
||||
inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i]))
|
||||
}
|
||||
ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])),
|
||||
ElemwisePrecision::U64 => {
|
||||
inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i]))
|
||||
}
|
||||
ElemwisePrecision::U32 => {
|
||||
inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i]))
|
||||
}
|
||||
ElemwisePrecision::U16 => {
|
||||
inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i]))
|
||||
}
|
||||
ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])),
|
||||
ElemwisePrecision::Bool => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -383,6 +421,18 @@ impl FuseOnWriteTrace {
|
|||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
SequenceArg::new(),
|
||||
);
|
||||
for item in handle_outputs.iter() {
|
||||
match item {
|
||||
|
@ -392,8 +442,15 @@ impl FuseOnWriteTrace {
|
|||
} => match precision {
|
||||
ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)),
|
||||
ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)),
|
||||
_ => todo!(),
|
||||
},
|
||||
HandleOutput::Owned {
|
||||
|
@ -406,11 +463,17 @@ impl FuseOnWriteTrace {
|
|||
match precision {
|
||||
ElemwisePrecision::F32 => outputs.t_f32.push(arg),
|
||||
ElemwisePrecision::F16 => outputs.t_f16.push(arg),
|
||||
ElemwisePrecision::BF16 => outputs.t_bf16.push(arg),
|
||||
ElemwisePrecision::I64 => outputs.t_i64.push(arg),
|
||||
ElemwisePrecision::I32 => outputs.t_i32.push(arg),
|
||||
ElemwisePrecision::I16 => outputs.t_i16.push(arg),
|
||||
ElemwisePrecision::I8 => outputs.t_i8.push(arg),
|
||||
ElemwisePrecision::U64 => outputs.t_u64.push(arg),
|
||||
ElemwisePrecision::U32 => outputs.t_u32.push(arg),
|
||||
ElemwisePrecision::U16 => outputs.t_u16.push(arg),
|
||||
ElemwisePrecision::U8 => outputs.t_u8.push(arg),
|
||||
// Bools are encoded as u32.
|
||||
ElemwisePrecision::Bool => outputs.t_u32.push(arg),
|
||||
_ => todo!(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,7 +127,7 @@ pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
    let vectorization_factor_rhs =
        tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1);

    let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
    let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs);

    let mut shape_out = vec![0; ndims];
    lhs.shape

@ -119,7 +119,7 @@ pub(crate) fn launch_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(
    let vectorization_factor_rhs =
        tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1);

    let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
    let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs);

    let mut shape_out = vec![0; ndims];
    lhs.shape

@ -169,7 +169,7 @@ pub(crate) fn launch_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(

        JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides)
    } else {
        let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
        let buffer = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
        let output =
            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
        let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;

@ -225,7 +225,7 @@ pub(crate) fn launch_scalar_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>
            tensor.strides,
        )
    } else {
        let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
        let buffer = tensor.client.empty(num_elems * core::mem::size_of::<u32>());
        let output = JitTensor::new(
            tensor.client.clone(),
            buffer,
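Illustration (not part of the diff) of the two hunks above: comparison kernels produce boolean outputs, and bools are encoded as u32 in this backend (see the "Bools are encoded as u32" comment earlier in the commit), so the freshly allocated output buffer is now sized from `u32` instead of the compared element type `E`:

// Sketch of the sizing rule, not the actual function.
fn cmp_output_buffer_bytes(num_elems: usize) -> usize {
    num_elems * core::mem::size_of::<u32>()
}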
@ -60,7 +60,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
|
|||
options.groups,
|
||||
out_h,
|
||||
out_w,
|
||||
&input.device,
|
||||
&input.client,
|
||||
) {
|
||||
panic!(
|
||||
"Requirements for implicit GEMM not met:
|
||||
|
@ -84,7 +84,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
|
|||
let slice_size = pad_kh * pad_kw * pad_in_channels;
|
||||
|
||||
let (cmma_m, cmma_n, cmma_k) =
|
||||
find_cmma_size::<R, f16, F>(&input.device, gemm_m, gemm_k, gemm_n).unwrap();
|
||||
find_cmma_size::<R, f16, F>(&input.client, gemm_m, gemm_k, gemm_n).unwrap();
|
||||
|
||||
let cube_dim_x = 128;
|
||||
let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2);
|
||||
|
@ -187,7 +187,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
|
|||
fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 {
|
||||
let channels = channels as u8;
|
||||
let elems_per_thread = elems_per_thread as u8;
|
||||
let smaller = u8::min(channels, elems_per_thread);
|
||||
let smaller = Ord::min(channels, elems_per_thread);
|
||||
(1..=smaller)
|
||||
.rev()
|
||||
.filter(|it| supported_vecs.contains(it))
|
||||
|
@ -663,7 +663,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
|
|||
groups: usize,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
device: &R::Device,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
) -> bool {
|
||||
let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_size[0], kernel_size[1]);
|
||||
let batch_size = padded_batch_size(batch_size, out_h, out_w);
|
||||
|
@ -673,7 +673,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
|
|||
let gemm_n = out_channels;
|
||||
let gemm_k = in_channels * kernel_h * kernel_w;
|
||||
|
||||
let size = find_cmma_size::<R, f16, E>(device, gemm_m as u32, gemm_k as u32, gemm_n as u32);
|
||||
let size = find_cmma_size::<R, f16, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);
|
||||
|
||||
if let Some((cmma_m, cmma_k, cmma_n)) = size {
|
||||
let warps_per_cube = 8;
|
||||
|
@ -716,12 +716,12 @@ fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize {
|
|||
}
|
||||
|
||||
fn find_cmma_size<R: JitRuntime, F: Float, FAcc: Float>(
|
||||
device: &R::JitDevice,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
gemm_m: u32,
|
||||
gemm_k: u32,
|
||||
gemm_n: u32,
|
||||
) -> Option<(u32, u32, u32)> {
|
||||
supported_cmma_sizes::<R, F, FAcc>(device)
|
||||
supported_cmma_sizes::<R, F, FAcc>(client)
|
||||
.into_iter()
|
||||
.find(|(m, k, n)| {
|
||||
gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0
|
||||
|
@ -730,7 +730,7 @@ fn find_cmma_size<R: JitRuntime, F: Float, FAcc: Float>(
|
|||
}
|
||||
|
||||
fn supported_cmma_sizes<R: JitRuntime, F: Float, FAcc: Float>(
|
||||
device: &R::JitDevice,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
) -> Vec<(u8, u8, u8)> {
|
||||
let requested_sizes = [(16, 16, 16), (32, 16, 8), (8, 16, 32)];
|
||||
|
||||
|
@ -738,9 +738,7 @@ fn supported_cmma_sizes<R: JitRuntime, F: Float, FAcc: Float>(
|
|||
.iter()
|
||||
.copied()
|
||||
.filter(|(m, k, n)| {
|
||||
R::client(device)
|
||||
.properties()
|
||||
.feature_enabled(Feature::Cmma {
|
||||
client.properties().feature_enabled(Feature::Cmma {
|
||||
a: F::as_elem(),
|
||||
b: F::as_elem(),
|
||||
c: FAcc::as_elem(),
|
||||
|
|
|
@ -3,6 +3,7 @@ use cubecl::{
|
|||
ir::{
|
||||
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
|
||||
},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -43,14 +44,14 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
let input_shape_0 = scope.create_local(Elem::UInt);
|
||||
let input_shape_1 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
let input_stride_0 = scope.create_local(u32::as_elem());
|
||||
let input_stride_1 = scope.create_local(u32::as_elem());
|
||||
let input_stride_2 = scope.create_local(u32::as_elem());
|
||||
let input_stride_3 = scope.create_local(u32::as_elem());
|
||||
let input_shape_0 = scope.create_local(u32::as_elem());
|
||||
let input_shape_1 = scope.create_local(u32::as_elem());
|
||||
let input_shape_2 = scope.create_local(u32::as_elem());
|
||||
let input_shape_3 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, input_stride_0 = stride(input, 0u32));
|
||||
cpa!(scope, input_stride_1 = stride(input, 1u32));
|
||||
cpa!(scope, input_stride_2 = stride(input, 2u32));
|
||||
|
@ -60,14 +61,14 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
cpa!(scope, input_shape_2 = shape(input, 2u32));
|
||||
cpa!(scope, input_shape_3 = shape(input, 3u32));
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, output_stride_0 = stride(output, 0u32));
|
||||
cpa!(scope, output_stride_1 = stride(output, 1u32));
|
||||
cpa!(scope, output_stride_2 = stride(output, 2u32));
|
||||
|
@ -77,14 +78,14 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let weight_stride_0 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_1 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_2 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_3 = scope.create_local(Elem::UInt);
|
||||
let in_channels = scope.create_local(Elem::UInt);
|
||||
let weight_shape_1 = scope.create_local(Elem::UInt);
|
||||
let kernel_size_0 = scope.create_local(Elem::UInt);
|
||||
let kernel_size_1 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_0 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_1 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_2 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_3 = scope.create_local(u32::as_elem());
|
||||
let in_channels = scope.create_local(u32::as_elem());
|
||||
let weight_shape_1 = scope.create_local(u32::as_elem());
|
||||
let kernel_size_0 = scope.create_local(u32::as_elem());
|
||||
let kernel_size_1 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, weight_stride_0 = stride(weight, 0u32));
|
||||
cpa!(scope, weight_stride_1 = stride(weight, 1u32));
|
||||
cpa!(scope, weight_stride_2 = stride(weight, 2u32));
|
||||
|
@ -94,31 +95,31 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
|
||||
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||
|
||||
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
|
||||
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
|
||||
let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
|
||||
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
|
||||
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
|
||||
let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(u32::as_elem()));
|
||||
|
||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
cpa!(scope, stride_0_i = cast(conv_stride_0));
|
||||
cpa!(scope, stride_1_i = cast(conv_stride_1));
|
||||
|
||||
let oc_out = scope.create_local(Elem::UInt);
|
||||
let oc = scope.create_local(Elem::UInt);
|
||||
let oc_out = scope.create_local(u32::as_elem());
|
||||
let oc = scope.create_local(u32::as_elem());
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let oh = scope.create_local(Elem::UInt);
|
||||
let ow = scope.create_local(Elem::UInt);
|
||||
let k = scope.create_local(Elem::UInt);
|
||||
let g = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let oh = scope.create_local(u32::as_elem());
|
||||
let ow = scope.create_local(u32::as_elem());
|
||||
let k = scope.create_local(u32::as_elem());
|
||||
let g = scope.create_local(u32::as_elem());
|
||||
|
||||
let ic_start = scope.create_local(Elem::UInt);
|
||||
let ic_end = scope.create_local(Elem::UInt);
|
||||
let ic_tmp = scope.create_local(Elem::UInt);
|
||||
let ic_start = scope.create_local(u32::as_elem());
|
||||
let ic_end = scope.create_local(u32::as_elem());
|
||||
let ic_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -141,20 +142,20 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
cpa!(scope, ic_start = g * ic_tmp);
|
||||
cpa!(scope, ic_end = ic_start + ic_tmp);
|
||||
|
||||
let tmp_u = scope.create_local(Elem::UInt);
|
||||
let tmp_u = scope.create_local(u32::as_elem());
|
||||
let tmp_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let zero_i = scope.zero(Elem::Int(IntKind::I32));
|
||||
let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32));
|
||||
|
||||
let kms_u = scope.create_local(Elem::UInt);
|
||||
let kms_u = scope.create_local(u32::as_elem());
|
||||
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ih_start = scope.create_local(Elem::UInt);
|
||||
let iw_start = scope.create_local(Elem::UInt);
|
||||
let ih_end = scope.create_local(Elem::UInt);
|
||||
let iw_end = scope.create_local(Elem::UInt);
|
||||
let ih_start = scope.create_local(u32::as_elem());
|
||||
let iw_start = scope.create_local(u32::as_elem());
|
||||
let ih_end = scope.create_local(u32::as_elem());
|
||||
let iw_end = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, kms_u = kernel_size_0 * dilation_0);
|
||||
cpa!(scope, kms_0 = cast(kms_u));
|
||||
|
@ -188,17 +189,17 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
cpa!(scope, tmp_u = cast(tmp_i));
|
||||
cpa!(scope, iw_end = min(tmp_u, input_shape_3));
|
||||
|
||||
let index_input = scope.create_local(Elem::UInt);
|
||||
let index_weight = scope.create_local(Elem::UInt);
|
||||
let index_input = scope.create_local(u32::as_elem());
|
||||
let index_weight = scope.create_local(u32::as_elem());
|
||||
|
||||
let index_input_b = scope.create_local(Elem::UInt);
|
||||
let index_input_ic = scope.create_local(Elem::UInt);
|
||||
let index_input_ih = scope.create_local(Elem::UInt);
|
||||
let index_input_iw = scope.create_local(Elem::UInt);
|
||||
let index_weight_ic = scope.create_local(Elem::UInt);
|
||||
let index_weight_oc = scope.create_local(Elem::UInt);
|
||||
let index_weight_kh = scope.create_local(Elem::UInt);
|
||||
let index_weight_kw = scope.create_local(Elem::UInt);
|
||||
let index_input_b = scope.create_local(u32::as_elem());
|
||||
let index_input_ic = scope.create_local(u32::as_elem());
|
||||
let index_input_ih = scope.create_local(u32::as_elem());
|
||||
let index_input_iw = scope.create_local(u32::as_elem());
|
||||
let index_weight_ic = scope.create_local(u32::as_elem());
|
||||
let index_weight_oc = scope.create_local(u32::as_elem());
|
||||
let index_weight_kh = scope.create_local(u32::as_elem());
|
||||
let index_weight_kw = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_input_b = b * input_stride_0);
|
||||
cpa!(scope, index_weight_oc = oc * weight_stride_1);
|
||||
|
@ -208,15 +209,15 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
let sum = scope.create_local(output.item);
|
||||
cpa!(scope, sum = bias[oc_out]);
|
||||
|
||||
let kh = scope.create_local(Elem::UInt);
|
||||
let kw = scope.create_local(Elem::UInt);
|
||||
let numerator_h_base = scope.create_local(Elem::UInt);
|
||||
let numerator_h = scope.create_local(Elem::UInt);
|
||||
let numerator_w_base = scope.create_local(Elem::UInt);
|
||||
let numerator_w = scope.create_local(Elem::UInt);
|
||||
let numerator_tmp = scope.create_local(Elem::UInt);
|
||||
let numerator_mod = scope.create_local(Elem::UInt);
|
||||
let zero = scope.zero(Elem::UInt);
|
||||
let kh = scope.create_local(u32::as_elem());
|
||||
let kw = scope.create_local(u32::as_elem());
|
||||
let numerator_h_base = scope.create_local(u32::as_elem());
|
||||
let numerator_h = scope.create_local(u32::as_elem());
|
||||
let numerator_w_base = scope.create_local(u32::as_elem());
|
||||
let numerator_w = scope.create_local(u32::as_elem());
|
||||
let numerator_tmp = scope.create_local(u32::as_elem());
|
||||
let numerator_mod = scope.create_local(u32::as_elem());
|
||||
let zero = scope.zero(u32::as_elem());
|
||||
let divisible = scope.create_local(Elem::Bool);
|
||||
let not_neg = scope.create_local(Elem::Bool);
|
||||
let cond = scope.create_local(Elem::Bool);
|
||||
|
@ -324,7 +325,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: 7,
|
||||
};
|
||||
|
||||
|
|
|
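Editor's note: the recurring change in these shader builders replaces the hard-coded Elem::UInt variant with u32::as_elem(), so the element type of index variables is derived from the Rust primitive (via cubecl's CubePrimitive trait) instead of a fixed enum variant, which is what lets the JIT backend support more index types. A rough, self-contained sketch of the pattern with stand-in types (the real Elem and trait live in cubecl):

    // Stand-in types; in the real code these are cubecl's `Elem` and `CubePrimitive`.
    #[allow(dead_code)]
    #[derive(Debug, Clone, Copy, PartialEq)]
    enum Elem {
        UInt,
        Float,
    }

    trait AsElem {
        fn as_elem() -> Elem;
    }

    impl AsElem for u32 {
        fn as_elem() -> Elem {
            Elem::UInt
        }
    }

    fn main() {
        // Before: scope.create_local(Elem::UInt)
        // After:  scope.create_local(u32::as_elem())
        assert_eq!(u32::as_elem(), Elem::UInt);
    }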
@@ -9,10 +9,7 @@ use cubecl::{

use crate::{
    kernel::{
        conv::{
            batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col,
            conv2d_implicit_gemm,
        },
        conv::{batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col},
        prng::random_uniform,
    },
    tensor::JitTensor,

@@ -42,7 +39,7 @@ pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
}

#[tune(
    operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm),
    operations(conv2d_direct, conv2d_im2col),
    create_key = create_key,
    should_run = should_run
)]

@@ -111,7 +108,7 @@ fn should_run<R: JitRuntime, F: FloatElement, I: IntElement>(
            op.options.groups,
            out_h,
            out_w,
            &op.input.device,
            &op.input.client,
        ),
        _ => true,
    }

@@ -143,5 +140,6 @@ fn create_key<R: JitRuntime, E: FloatElement>(
        width,
        batch_size,
        bias.is_some(),
        E::dtype(),
    ))
}
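Editor's note: with conv2d_implicit_gemm dropped from the #[tune] list, only conv2d_direct and conv2d_im2col are benchmarked here, and create_key/should_run decide which key a workload hashes to and whether a candidate is eligible at all. The following is a deliberately simplified, hypothetical illustration of the general autotune idea (it is not the output of cubecl's #[tune] macro): time each eligible candidate once per key and cache the fastest.

    use std::collections::HashMap;
    use std::time::{Duration, Instant};

    // Hypothetical sketch: for a given key, time each candidate once and cache
    // the index of the fastest one.
    fn autotune(cache: &mut HashMap<String, usize>, key: &str, candidates: &[fn()]) -> usize {
        if let Some(&best) = cache.get(key) {
            return best;
        }
        let mut best = (0usize, Duration::MAX);
        for (i, run) in candidates.iter().enumerate() {
            let start = Instant::now();
            run();
            let elapsed = start.elapsed();
            if elapsed < best.1 {
                best = (i, elapsed);
            }
        }
        cache.insert(key.to_string(), best.0);
        best.0
    }

    fn main() {
        let mut cache = HashMap::new();
        let candidates: [fn(); 2] = [|| {}, || {}];
        let fastest = autotune(&mut cache, "conv2d-example-key", &candidates);
        println!("fastest candidate index: {fastest}");
    }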
@@ -91,6 +91,7 @@ fn create_key<R: JitRuntime, E: FloatElement>(
        width,
        batch_size,
        bias.is_some(),
        E::dtype(),
    ))
}
@@ -1,3 +1,4 @@
use burn_tensor::DType;
use cubecl::AutotuneKey;
use serde::{Deserialize, Serialize};

@@ -20,6 +21,7 @@ pub struct Conv2dAutotuneKey {
    #[autotune(anchor)]
    pub batch_size: usize,
    pub has_bias: bool,
    pub dtype: DType,
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]

@@ -42,4 +44,5 @@ pub struct ConvTranspose2dAutotuneKey {
    #[autotune(anchor)]
    pub batch_size: usize,
    pub has_bias: bool,
    pub dtype: DType,
}
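Editor's note: adding a dtype field to these keys means a kernel tuned for one element type is never reused for another, because two otherwise identical workloads now compare and hash differently. A hedged sketch with a simplified stand-in key (the real structs carry more fields and use the AutotuneKey derive):

    // Simplified stand-in for the derived autotune key; it only illustrates why
    // `dtype` now participates in equality and hashing.
    #[derive(Hash, PartialEq, Eq, Debug, Clone, Copy)]
    enum DType {
        F32,
        F16,
    }

    #[derive(Hash, PartialEq, Eq, Debug, Clone, Copy)]
    struct Key {
        batch_size: usize,
        has_bias: bool,
        dtype: DType,
    }

    fn main() {
        let a = Key { batch_size: 8, has_bias: true, dtype: DType::F32 };
        let b = Key { dtype: DType::F16, ..a };
        // Same geometry, different element type: two separate autotune cache entries.
        assert_ne!(a, b);
    }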
@ -3,6 +3,7 @@ use cubecl::{
|
|||
ir::{
|
||||
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
|
||||
},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -43,16 +44,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
let output = self.output;
|
||||
let idx = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
let input_stride_4 = scope.create_local(Elem::UInt);
|
||||
let input_shape_0 = scope.create_local(Elem::UInt);
|
||||
let input_shape_1 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
let input_shape_4 = scope.create_local(Elem::UInt);
|
||||
let input_stride_0 = scope.create_local(u32::as_elem());
|
||||
let input_stride_1 = scope.create_local(u32::as_elem());
|
||||
let input_stride_2 = scope.create_local(u32::as_elem());
|
||||
let input_stride_3 = scope.create_local(u32::as_elem());
|
||||
let input_stride_4 = scope.create_local(u32::as_elem());
|
||||
let input_shape_0 = scope.create_local(u32::as_elem());
|
||||
let input_shape_1 = scope.create_local(u32::as_elem());
|
||||
let input_shape_2 = scope.create_local(u32::as_elem());
|
||||
let input_shape_3 = scope.create_local(u32::as_elem());
|
||||
let input_shape_4 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, input_stride_0 = stride(input, 0u32));
|
||||
cpa!(scope, input_stride_1 = stride(input, 1u32));
|
||||
cpa!(scope, input_stride_2 = stride(input, 2u32));
|
||||
|
@ -64,16 +65,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, input_shape_3 = shape(input, 3u32));
|
||||
cpa!(scope, input_shape_4 = shape(input, 4u32));
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_4 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_4 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
let output_stride_4 = scope.create_local(u32::as_elem());
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
let output_shape_4 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, output_stride_0 = stride(output, 0u32));
|
||||
cpa!(scope, output_stride_1 = stride(output, 1u32));
|
||||
cpa!(scope, output_stride_2 = stride(output, 2u32));
|
||||
|
@ -85,16 +86,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
cpa!(scope, output_shape_4 = shape(output, 4u32));
|
||||
|
||||
let weight_stride_0 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_1 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_2 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_3 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_4 = scope.create_local(Elem::UInt);
|
||||
let in_channels = scope.create_local(Elem::UInt);
|
||||
let weight_shape_1 = scope.create_local(Elem::UInt);
|
||||
let kernel_size_0 = scope.create_local(Elem::UInt);
|
||||
let kernel_size_1 = scope.create_local(Elem::UInt);
|
||||
let kernel_size_2 = scope.create_local(Elem::UInt);
|
||||
let weight_stride_0 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_1 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_2 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_3 = scope.create_local(u32::as_elem());
|
||||
let weight_stride_4 = scope.create_local(u32::as_elem());
|
||||
let in_channels = scope.create_local(u32::as_elem());
|
||||
let weight_shape_1 = scope.create_local(u32::as_elem());
|
||||
let kernel_size_0 = scope.create_local(u32::as_elem());
|
||||
let kernel_size_1 = scope.create_local(u32::as_elem());
|
||||
let kernel_size_2 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, weight_stride_0 = stride(weight, 0u32));
|
||||
cpa!(scope, weight_stride_1 = stride(weight, 1u32));
|
||||
cpa!(scope, weight_stride_2 = stride(weight, 2u32));
|
||||
|
@ -106,16 +107,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||
cpa!(scope, kernel_size_2 = shape(weight, 4u32));
|
||||
|
||||
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
|
||||
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
|
||||
let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
|
||||
let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(Elem::UInt));
|
||||
let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(Elem::UInt));
|
||||
let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(Elem::UInt));
|
||||
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
|
||||
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
|
||||
let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
|
||||
let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(u32::as_elem()));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(u32::as_elem()));
|
||||
let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(u32::as_elem()));
|
||||
let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(u32::as_elem()));
|
||||
|
||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
@ -124,19 +125,19 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, stride_1_i = cast(conv_stride_1));
|
||||
cpa!(scope, stride_2_i = cast(conv_stride_2));
|
||||
|
||||
let oc_out = scope.create_local(Elem::UInt);
|
||||
let oc = scope.create_local(Elem::UInt);
|
||||
let oc_out = scope.create_local(u32::as_elem());
|
||||
let oc = scope.create_local(u32::as_elem());
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let od = scope.create_local(Elem::UInt);
|
||||
let oh = scope.create_local(Elem::UInt);
|
||||
let ow = scope.create_local(Elem::UInt);
|
||||
let k = scope.create_local(Elem::UInt);
|
||||
let g = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let od = scope.create_local(u32::as_elem());
|
||||
let oh = scope.create_local(u32::as_elem());
|
||||
let ow = scope.create_local(u32::as_elem());
|
||||
let k = scope.create_local(u32::as_elem());
|
||||
let g = scope.create_local(u32::as_elem());
|
||||
|
||||
let ic_start = scope.create_local(Elem::UInt);
|
||||
let ic_end = scope.create_local(Elem::UInt);
|
||||
let ic_tmp = scope.create_local(Elem::UInt);
|
||||
let ic_start = scope.create_local(u32::as_elem());
|
||||
let ic_end = scope.create_local(u32::as_elem());
|
||||
let ic_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = idx / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
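Editor's note: the cpa! pairs of the form `b = idx / output_stride_0; b = b % output_shape_0` decompose the flat thread index into per-axis coordinates of a row-major tensor. The same arithmetic in plain Rust, as a reference sketch of what the kernel computes with its stride and shape variables (not the kernel itself):

    /// Recover the coordinate along each axis from a flat row-major index,
    /// mirroring the kernel's `coord = idx / stride(axis) % shape(axis)` pattern.
    fn coords_from_flat_index(idx: usize, shape: &[usize]) -> Vec<usize> {
        // Row-major strides: stride(last) == 1, stride(i) == product of later dims.
        let mut strides = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        shape
            .iter()
            .zip(&strides)
            .map(|(&dim, &stride)| idx / stride % dim)
            .collect()
    }

    fn main() {
        // A [2, 3, 4, 5, 6] tensor: flat index 321 -> coordinates (0, 2, 2, 3, 3).
        assert_eq!(coords_from_flat_index(321, &[2, 3, 4, 5, 6]), vec![0, 2, 2, 3, 3]);
    }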
|
@ -162,24 +163,24 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, ic_start = g * ic_tmp);
|
||||
cpa!(scope, ic_end = ic_start + ic_tmp);
|
||||
|
||||
let tmp_u = scope.create_local(Elem::UInt);
|
||||
let tmp_u = scope.create_local(u32::as_elem());
|
||||
let tmp_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let zero_i = scope.zero(Elem::Int(IntKind::I32));
|
||||
let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32));
|
||||
|
||||
let kms_u = scope.create_local(Elem::UInt);
|
||||
let kms_u = scope.create_local(u32::as_elem());
|
||||
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let kms_2 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let id_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let id_start = scope.create_local(Elem::UInt);
|
||||
let ih_start = scope.create_local(Elem::UInt);
|
||||
let iw_start = scope.create_local(Elem::UInt);
|
||||
let id_end = scope.create_local(Elem::UInt);
|
||||
let ih_end = scope.create_local(Elem::UInt);
|
||||
let iw_end = scope.create_local(Elem::UInt);
|
||||
let id_start = scope.create_local(u32::as_elem());
|
||||
let ih_start = scope.create_local(u32::as_elem());
|
||||
let iw_start = scope.create_local(u32::as_elem());
|
||||
let id_end = scope.create_local(u32::as_elem());
|
||||
let ih_end = scope.create_local(u32::as_elem());
|
||||
let iw_end = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, kms_u = kernel_size_0 * dilation_0);
|
||||
cpa!(scope, kms_0 = cast(kms_u));
|
||||
|
@ -228,19 +229,19 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, tmp_u = cast(tmp_i));
|
||||
cpa!(scope, iw_end = min(tmp_u, input_shape_4));
|
||||
|
||||
let index_input = scope.create_local(Elem::UInt);
|
||||
let index_weight = scope.create_local(Elem::UInt);
|
||||
let index_input = scope.create_local(u32::as_elem());
|
||||
let index_weight = scope.create_local(u32::as_elem());
|
||||
|
||||
let index_input_b = scope.create_local(Elem::UInt);
|
||||
let index_input_ic = scope.create_local(Elem::UInt);
|
||||
let index_input_id = scope.create_local(Elem::UInt);
|
||||
let index_input_ih = scope.create_local(Elem::UInt);
|
||||
let index_input_iw = scope.create_local(Elem::UInt);
|
||||
let index_weight_ic = scope.create_local(Elem::UInt);
|
||||
let index_weight_oc = scope.create_local(Elem::UInt);
|
||||
let index_weight_kd = scope.create_local(Elem::UInt);
|
||||
let index_weight_kh = scope.create_local(Elem::UInt);
|
||||
let index_weight_kw = scope.create_local(Elem::UInt);
|
||||
let index_input_b = scope.create_local(u32::as_elem());
|
||||
let index_input_ic = scope.create_local(u32::as_elem());
|
||||
let index_input_id = scope.create_local(u32::as_elem());
|
||||
let index_input_ih = scope.create_local(u32::as_elem());
|
||||
let index_input_iw = scope.create_local(u32::as_elem());
|
||||
let index_weight_ic = scope.create_local(u32::as_elem());
|
||||
let index_weight_oc = scope.create_local(u32::as_elem());
|
||||
let index_weight_kd = scope.create_local(u32::as_elem());
|
||||
let index_weight_kh = scope.create_local(u32::as_elem());
|
||||
let index_weight_kw = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_input_b = b * input_stride_0);
|
||||
cpa!(scope, index_weight_oc = oc * weight_stride_1);
|
||||
|
@ -250,18 +251,18 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
let sum = scope.create_local(output.item);
|
||||
cpa!(scope, sum = bias[oc_out]);
|
||||
|
||||
let kd = scope.create_local(Elem::UInt);
|
||||
let kh = scope.create_local(Elem::UInt);
|
||||
let kw = scope.create_local(Elem::UInt);
|
||||
let numerator_d_base = scope.create_local(Elem::UInt);
|
||||
let numerator_d = scope.create_local(Elem::UInt);
|
||||
let numerator_h_base = scope.create_local(Elem::UInt);
|
||||
let numerator_h = scope.create_local(Elem::UInt);
|
||||
let numerator_w_base = scope.create_local(Elem::UInt);
|
||||
let numerator_w = scope.create_local(Elem::UInt);
|
||||
let numerator_tmp = scope.create_local(Elem::UInt);
|
||||
let numerator_mod = scope.create_local(Elem::UInt);
|
||||
let zero = scope.zero(Elem::UInt);
|
||||
let kd = scope.create_local(u32::as_elem());
|
||||
let kh = scope.create_local(u32::as_elem());
|
||||
let kw = scope.create_local(u32::as_elem());
|
||||
let numerator_d_base = scope.create_local(u32::as_elem());
|
||||
let numerator_d = scope.create_local(u32::as_elem());
|
||||
let numerator_h_base = scope.create_local(u32::as_elem());
|
||||
let numerator_h = scope.create_local(u32::as_elem());
|
||||
let numerator_w_base = scope.create_local(u32::as_elem());
|
||||
let numerator_w = scope.create_local(u32::as_elem());
|
||||
let numerator_tmp = scope.create_local(u32::as_elem());
|
||||
let numerator_mod = scope.create_local(u32::as_elem());
|
||||
let zero = scope.zero(u32::as_elem());
|
||||
let divisible = scope.create_local(Elem::Bool);
|
||||
let not_neg = scope.create_local(Elem::Bool);
|
||||
let cond = scope.create_local(Elem::Bool);
|
||||
|
@ -392,7 +393,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: 10,
|
||||
};
|
||||
|
||||
|
|
|
@@ -5,7 +5,7 @@ use burn_tensor::{
use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch};

use crate::{
    kernel::into_contiguous,
    kernel::{cast, into_contiguous},
    ops::{
        numeric::{empty_device, ones_device, zeros_device},
        reshape, swap_dims,

@@ -426,7 +426,8 @@ fn compute_input_grad<R: JitRuntime, E: FloatElement>(
    let [batch_size, in_channels, height, width] = input_shape.dims();
    let (kernel_height, kernel_width) = kernel_dims;

    let grad_in = zeros_device::<R, E>(
    // Force `f32` to enable bitcasting as `u32`
    let grad_in = zeros_device::<R, f32>(
        client.clone(),
        device.clone(),
        Shape::new([batch_size, in_channels, height, width]),

@@ -466,7 +467,7 @@ fn compute_input_grad<R: JitRuntime, E: FloatElement>(
        use_mask,
    );

    grad_in
    cast(grad_in)
}

#[derive(CubeLaunch)]

@@ -564,19 +565,19 @@ fn deform_col2img_kernel<F: Float>(

                let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];

                float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
                float_atomic_add(&mut grad_input[gradient_pos], f32::cast_from(value));
            }
        }
    }
}

#[cube]
fn float_atomic_add<F: Float>(ptr: &mut AtomicU32, value: F) {
    if value != F::new(0.0) {
fn float_atomic_add(ptr: &mut AtomicU32, value: f32) {
    if value != 0.0 {
        let mut v = AtomicU32::load(ptr);
        loop {
            let prev = v;
            let v_float = F::bitcast_from(v);
            let v_float = f32::bitcast_from(v);
            let new = u32::bitcast_from(v_float + value);
            v = AtomicU32::compare_and_swap(ptr, v, new);
            if prev == v {
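Editor's note: float_atomic_add is now monomorphic over f32 because the gradient buffer is allocated as f32 (the zeros_device::<R, f32> change above) so its bits can be reinterpreted as u32 and updated with a compare-and-swap loop; most GPU backends only expose integer atomics. The same idea on the CPU, as a hedged sketch using std atomics rather than the cubecl intrinsics:

    use std::sync::atomic::{AtomicU32, Ordering};

    /// Add `value` to a float stored in an AtomicU32 by bit-casting and retrying
    /// the compare-and-swap until no other thread races the update.
    fn float_atomic_add(cell: &AtomicU32, value: f32) {
        if value == 0.0 {
            return;
        }
        let mut current = cell.load(Ordering::Relaxed);
        loop {
            let new = f32::to_bits(f32::from_bits(current) + value);
            match cell.compare_exchange(current, new, Ordering::Relaxed, Ordering::Relaxed) {
                Ok(_) => break,
                Err(observed) => current = observed,
            }
        }
    }

    fn main() {
        let cell = AtomicU32::new(f32::to_bits(1.5));
        float_atomic_add(&cell, 2.25);
        assert_eq!(f32::from_bits(cell.load(Ordering::Relaxed)), 3.75);
    }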
@ -5,6 +5,7 @@ use burn_tensor::ElementConversion;
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -29,12 +30,12 @@ impl FlipComputeShader {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
let offset_local = scope.create_local(Elem::UInt);
|
||||
let offset_input = scope.zero(u32::as_elem());
|
||||
let offset_local = scope.create_local(u32::as_elem());
|
||||
|
||||
let stride = scope.create_local(Elem::UInt);
|
||||
let shape = scope.create_local(Elem::UInt);
|
||||
let flip = scope.create_local(Elem::UInt);
|
||||
let stride = scope.create_local(u32::as_elem());
|
||||
let shape = scope.create_local(u32::as_elem());
|
||||
let flip = scope.create_local(u32::as_elem());
|
||||
let flip_bool = scope.create_local(Elem::Bool);
|
||||
|
||||
for i in 0..self.rank {
|
||||
|
@ -44,7 +45,7 @@ impl FlipComputeShader {
|
|||
scope,
|
||||
flip = cast(Variable::new(
|
||||
VariableKind::GlobalScalar(i as u16),
|
||||
Item::new(Elem::UInt)
|
||||
Item::new(u32::as_elem())
|
||||
))
|
||||
);
|
||||
cpa!(scope, flip_bool = flip == 1u32);
|
||||
|
@ -89,7 +90,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let flip_dims = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: self.rank,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
|
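Editor's note: the flip shader above walks every axis, recovers the output coordinate from strides and shapes, and mirrors it (shape - 1 - coord) whenever the corresponding flip scalar is set. A reference sketch of that index mapping in plain Rust (assumed row-major layout, not the kernel itself):

    /// Map a flat output index to the flat input index it reads from when the
    /// listed axes are flipped.
    fn flipped_source_index(out_idx: usize, shape: &[usize], flip: &[bool]) -> usize {
        let mut strides = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        let mut in_idx = 0;
        for ((&dim, &stride), &flipped) in shape.iter().zip(&strides).zip(flip) {
            let coord = out_idx / stride % dim;
            let src = if flipped { dim - 1 - coord } else { coord };
            in_idx += src * stride;
        }
        in_idx
    }

    fn main() {
        // Flipping the last axis of a [2, 3] tensor: element (0, 0) reads from (0, 2).
        assert_eq!(flipped_source_index(0, &[2, 3], &[false, true]), 2);
    }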
|
@ -1,7 +1,8 @@
|
|||
use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
|
||||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::CubePrimitive,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -28,12 +29,12 @@ impl RepeatComputeShader {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
let offset_local = scope.zero(Elem::UInt);
|
||||
let offset_input = scope.zero(u32::as_elem());
|
||||
let offset_local = scope.zero(u32::as_elem());
|
||||
|
||||
let stride_input = scope.create_local(Elem::UInt);
|
||||
let stride_output = scope.create_local(Elem::UInt);
|
||||
let shape = scope.create_local(Elem::UInt);
|
||||
let stride_input = scope.create_local(u32::as_elem());
|
||||
let stride_output = scope.create_local(u32::as_elem());
|
||||
let shape = scope.create_local(u32::as_elem());
|
||||
|
||||
for i in 0..self.rank {
|
||||
cpa!(scope, stride_input = stride(input, i));
|
||||
|
|
|
@@ -2,15 +2,15 @@ use crate::{
    element::JitElement,
    kernel::{self},
    tensor::JitTensor,
    JitRuntime,
    IntElement, JitRuntime,
};
use cubecl::prelude::*;
use cubecl::{calculate_cube_count_elemwise, CubeDim};

#[cube(launch_unchecked)]
fn scatter_kernel<T: Numeric>(
fn scatter_kernel<T: Numeric, I: Int>(
    input: &mut Tensor<T>,
    indices: &Tensor<i32>,
    indices: &Tensor<I>,
    value: &Tensor<T>,
    dim: &u32,
) {

@@ -65,7 +65,7 @@ fn scatter_kernel<T: Numeric>(
    }
}

pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement>(
pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: IntElement>(
    dim: usize,
    tensor: JitTensor<R, E>,
    indices: JitTensor<R, I>,

@@ -105,7 +105,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement>(
    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

    unsafe {
        scatter_kernel::launch_unchecked::<E, R>(
        scatter_kernel::launch_unchecked::<E, I, R>(
            &indices.client.clone(),
            cube_count,
            cube_dim,
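Editor's note: the scatter kernel is now generic over the index element (I: Int / IntElement) instead of being fixed to i32, so it is launched with the extra type parameter. For readers unfamiliar with the operation itself, a hedged CPU reference of scatter along a dimension for a 2-D tensor, assuming the add-accumulate semantics burn uses (this is illustration only, not the JIT kernel):

    /// For every element of `value`, add it into `input` at the coordinate whose
    /// `dim`-th component is taken from `indices`.
    fn scatter_2d_add(
        input: &mut Vec<Vec<f32>>,
        dim: usize,
        indices: &[Vec<usize>],
        value: &[Vec<f32>],
    ) {
        for r in 0..value.len() {
            for c in 0..value[r].len() {
                let target = indices[r][c];
                match dim {
                    0 => input[target][c] += value[r][c],
                    _ => input[r][target] += value[r][c],
                }
            }
        }
    }

    fn main() {
        let mut input = vec![vec![0.0; 3]; 2];
        let indices = vec![vec![1, 0, 1]];
        let value = vec![vec![10.0, 20.0, 30.0]];
        scatter_2d_add(&mut input, 0, &indices, &value);
        assert_eq!(input, vec![vec![0.0, 20.0, 0.0], vec![10.0, 0.0, 30.0]]);
    }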
@ -4,7 +4,8 @@ use crate::{
|
|||
use burn_tensor::{ElementConversion, Shape};
|
||||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
ir::{Builtin, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -29,13 +30,13 @@ impl SliceComputeShader {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
let offset_local = scope.create_local(Elem::UInt);
|
||||
let offset_input = scope.zero(u32::as_elem());
|
||||
let offset_local = scope.create_local(u32::as_elem());
|
||||
|
||||
let stride_input = scope.create_local(Elem::UInt);
|
||||
let stride_output = scope.create_local(Elem::UInt);
|
||||
let shape_output = scope.create_local(Elem::UInt);
|
||||
let range_start = scope.create_local(Elem::UInt);
|
||||
let stride_input = scope.create_local(u32::as_elem());
|
||||
let stride_output = scope.create_local(u32::as_elem());
|
||||
let shape_output = scope.create_local(u32::as_elem());
|
||||
let range_start = scope.create_local(u32::as_elem());
|
||||
|
||||
for i in 0..self.rank {
|
||||
cpa!(scope, stride_input = stride(input, i));
|
||||
|
@ -45,7 +46,7 @@ impl SliceComputeShader {
|
|||
scope,
|
||||
range_start = cast(Variable::new(
|
||||
VariableKind::GlobalScalar(i as u16),
|
||||
Item::new(Elem::UInt)
|
||||
Item::new(u32::as_elem())
|
||||
))
|
||||
);
|
||||
|
||||
|
@ -85,7 +86,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let ranges = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: self.rank,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
|
|
@ -2,7 +2,8 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
|
|||
use burn_tensor::ElementConversion;
|
||||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
ir::{Builtin, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
};
|
||||
use std::{marker::PhantomData, ops::Range};
|
||||
|
@ -26,18 +27,18 @@ impl SliceAssignComputeShader {
|
|||
let value = self.value;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
let offset_value = scope.zero(Elem::UInt);
|
||||
let offset_input = scope.zero(u32::as_elem());
|
||||
let offset_value = scope.zero(u32::as_elem());
|
||||
|
||||
let offset_local = scope.create_local(Elem::UInt);
|
||||
let offset_local_value = scope.create_local(Elem::UInt);
|
||||
let offset_local_input = scope.create_local(Elem::UInt);
|
||||
let offset_local = scope.create_local(u32::as_elem());
|
||||
let offset_local_value = scope.create_local(u32::as_elem());
|
||||
let offset_local_input = scope.create_local(u32::as_elem());
|
||||
|
||||
let stride_input = scope.create_local(Elem::UInt);
|
||||
let stride_value = scope.create_local(Elem::UInt);
|
||||
let shape_value = scope.create_local(Elem::UInt);
|
||||
let shape_input = scope.create_local(Elem::UInt);
|
||||
let range_start = scope.create_local(Elem::UInt);
|
||||
let stride_input = scope.create_local(u32::as_elem());
|
||||
let stride_value = scope.create_local(u32::as_elem());
|
||||
let shape_value = scope.create_local(u32::as_elem());
|
||||
let shape_input = scope.create_local(u32::as_elem());
|
||||
let range_start = scope.create_local(u32::as_elem());
|
||||
|
||||
for i in 0..self.rank {
|
||||
cpa!(scope, stride_input = stride(input, i));
|
||||
|
@ -48,7 +49,7 @@ impl SliceAssignComputeShader {
|
|||
scope,
|
||||
range_start = cast(Variable::new(
|
||||
VariableKind::GlobalScalar(i as u16),
|
||||
Item::new(Elem::UInt)
|
||||
Item::new(u32::as_elem())
|
||||
))
|
||||
);
|
||||
|
||||
|
@ -98,7 +99,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let ranges = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: self.rank,
|
||||
};
|
||||
|
||||
|
|
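Editor's note: the slice-assign shader above converts each position of the value tensor into an offset inside the destination by adding the per-axis range_start scalar before multiplying by the destination stride. A reference sketch of that index math (assumed row-major strides, not the kernel itself):

    /// Flat destination offset for an element of the value tensor written by
    /// slice-assign: shift each coordinate by the slice start, then apply the
    /// destination strides.
    fn slice_assign_offset(
        value_coords: &[usize],
        range_start: &[usize],
        input_strides: &[usize],
    ) -> usize {
        value_coords
            .iter()
            .zip(range_start)
            .zip(input_strides)
            .map(|((&c, &start), &stride)| (c + start) * stride)
            .sum()
    }

    fn main() {
        // Writing value[(0, 2)] into input[1.., 3..] of a 4x8 row-major tensor
        // targets input[(1, 5)], i.e. flat offset 1 * 8 + 5 = 13.
        assert_eq!(slice_assign_offset(&[0, 2], &[1, 3], &[8, 1]), 13);
    }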
|
@ -1,6 +1,7 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -27,23 +28,23 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
let elem = E::cube_elem();
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
let input_stride_0 = scope.create_local(u32::as_elem());
|
||||
let input_stride_1 = scope.create_local(u32::as_elem());
|
||||
let input_stride_2 = scope.create_local(u32::as_elem());
|
||||
let input_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(u32::as_elem());
|
||||
let input_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, input_stride_0 = stride(input, 0u32));
|
||||
cpa!(scope, input_stride_1 = stride(input, 1u32));
|
||||
|
@ -63,10 +64,10 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let h = scope.create_local(Elem::UInt);
|
||||
let w = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let c = scope.create_local(u32::as_elem());
|
||||
let h = scope.create_local(u32::as_elem());
|
||||
let w = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -80,23 +81,23 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
cpa!(scope, w = id / output_stride_3);
|
||||
cpa!(scope, w = w % output_shape_3);
|
||||
|
||||
let input_height = scope.create_local(Elem::UInt);
|
||||
let output_height = scope.create_local(Elem::UInt);
|
||||
let input_height = scope.create_local(u32::as_elem());
|
||||
let output_height = scope.create_local(u32::as_elem());
|
||||
let output_height_float = scope.create_local(elem);
|
||||
|
||||
let input_width = scope.create_local(Elem::UInt);
|
||||
let output_width = scope.create_local(Elem::UInt);
|
||||
let input_width = scope.create_local(u32::as_elem());
|
||||
let output_width = scope.create_local(u32::as_elem());
|
||||
let output_width_float = scope.create_local(elem);
|
||||
|
||||
let frac = scope.create_local(elem);
|
||||
let numerator = scope.create_local(Elem::UInt);
|
||||
let numerator = scope.create_local(u32::as_elem());
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let not_zero = scope.create_local(Elem::Bool);
|
||||
|
||||
let y_in_float = scope.create_local(elem);
|
||||
let y_in = scope.create_local(Elem::UInt);
|
||||
let y_in = scope.create_local(u32::as_elem());
|
||||
let yw = scope.create_local(elem);
|
||||
let y_tmp = scope.create_local(Elem::UInt);
|
||||
let y_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, input_height = input_shape_2 - 1u32);
|
||||
cpa!(scope, output_height = output_shape_2 - 1u32);
|
||||
|
@ -109,7 +110,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
cpa!(scope, y_in = cast(y_in_float));
|
||||
cpa!(scope, yw = frac - y_in_float);
|
||||
|
||||
let y0 = scope.zero(Elem::UInt);
|
||||
let y0 = scope.zero(u32::as_elem());
|
||||
cpa!(scope, not_zero = y_in != 0u32);
|
||||
cpa!(scope, if(not_zero).then(|scope|{
|
||||
cpa!(scope, y0 = y_in - 1u32);
|
||||
|
@ -124,9 +125,9 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
let y3 = Self::min(scope, y_tmp, input_height);
|
||||
|
||||
let x_in_float = scope.create_local(elem);
|
||||
let x_in = scope.create_local(Elem::UInt);
|
||||
let x_in = scope.create_local(u32::as_elem());
|
||||
let xw = scope.create_local(elem);
|
||||
let x_tmp = scope.create_local(Elem::UInt);
|
||||
let x_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, input_width = input_shape_3 - 1u32);
|
||||
cpa!(scope, output_width = output_shape_3 - 1u32);
|
||||
|
@ -139,7 +140,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
cpa!(scope, x_in = cast(x_in_float));
|
||||
cpa!(scope, xw = frac - x_in_float);
|
||||
|
||||
let x0 = scope.zero(Elem::UInt);
|
||||
let x0 = scope.zero(u32::as_elem());
|
||||
cpa!(scope, not_zero = x_in != 0u32);
|
||||
cpa!(scope, if(not_zero).then(|scope|{
|
||||
cpa!(scope, x0 = x_in - 1u32);
|
||||
|
@ -154,20 +155,20 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
cpa!(scope, x_tmp = x_in + 2u32);
|
||||
let x3 = Self::min(scope, x_tmp, input_width);
|
||||
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index_base = scope.create_local(u32::as_elem());
|
||||
let index_tmp = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, index_base = b * input_stride_0);
|
||||
cpa!(scope, index_tmp = c * input_stride_1);
|
||||
cpa!(scope, index_base += index_tmp);
|
||||
|
||||
let y0_stride = scope.create_local(Elem::UInt);
|
||||
let y1_stride = scope.create_local(Elem::UInt);
|
||||
let y2_stride = scope.create_local(Elem::UInt);
|
||||
let y3_stride = scope.create_local(Elem::UInt);
|
||||
let x0_stride = scope.create_local(Elem::UInt);
|
||||
let x1_stride = scope.create_local(Elem::UInt);
|
||||
let x2_stride = scope.create_local(Elem::UInt);
|
||||
let x3_stride = scope.create_local(Elem::UInt);
|
||||
let y0_stride = scope.create_local(u32::as_elem());
|
||||
let y1_stride = scope.create_local(u32::as_elem());
|
||||
let y2_stride = scope.create_local(u32::as_elem());
|
||||
let y3_stride = scope.create_local(u32::as_elem());
|
||||
let x0_stride = scope.create_local(u32::as_elem());
|
||||
let x1_stride = scope.create_local(u32::as_elem());
|
||||
let x2_stride = scope.create_local(u32::as_elem());
|
||||
let x3_stride = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, y0_stride = y0 * input_stride_2);
|
||||
cpa!(scope, y1_stride = y1 * input_stride_2);
|
||||
cpa!(scope, y2_stride = y2 * input_stride_2);
|
||||
|
@ -177,10 +178,10 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
|
|||
cpa!(scope, x2_stride = x2 * input_stride_3);
|
||||
cpa!(scope, x3_stride = x3 * input_stride_3);
|
||||
|
||||
let index_0 = scope.create_local(Elem::UInt);
|
||||
let index_1 = scope.create_local(Elem::UInt);
|
||||
let index_2 = scope.create_local(Elem::UInt);
|
||||
let index_3 = scope.create_local(Elem::UInt);
|
||||
let index_0 = scope.create_local(u32::as_elem());
|
||||
let index_1 = scope.create_local(u32::as_elem());
|
||||
let index_2 = scope.create_local(u32::as_elem());
|
||||
let index_3 = scope.create_local(u32::as_elem());
|
||||
let inp_0 = scope.create_local(input.item);
|
||||
let inp_1 = scope.create_local(input.item);
|
||||
let inp_2 = scope.create_local(input.item);
|
||||
|
|
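Editor's note: the bicubic shader picks four rows y0..y3 (and columns x0..x3) around the source coordinate, clamping at the borders: y0 stays 0 when y_in is 0, and y3 is min(y_in + 2, input_height). A hedged sketch of just that index selection (the cubic weights are omitted; the middle taps are an assumption consistent with the clamping shown):

    /// The four bicubic sample rows around `y_in`, clamped to [0, height - 1].
    fn bicubic_taps(y_in: usize, height: usize) -> [usize; 4] {
        let max = height - 1;
        [
            y_in.saturating_sub(1),
            y_in.min(max),
            (y_in + 1).min(max),
            (y_in + 2).min(max),
        ]
    }

    fn main() {
        assert_eq!(bicubic_taps(0, 10), [0, 0, 1, 2]); // clamped at the top edge
        assert_eq!(bicubic_taps(8, 10), [7, 8, 9, 9]); // clamped at the bottom edge
    }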
|
@ -1,6 +1,7 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -25,23 +26,23 @@ impl InterpolateBilinearShader {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
let input_stride_0 = scope.create_local(u32::as_elem());
|
||||
let input_stride_1 = scope.create_local(u32::as_elem());
|
||||
let input_stride_2 = scope.create_local(u32::as_elem());
|
||||
let input_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(u32::as_elem());
|
||||
let input_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, input_stride_0 = stride(input, 0u32));
|
||||
cpa!(scope, input_stride_1 = stride(input, 1u32));
|
||||
|
@ -61,10 +62,10 @@ impl InterpolateBilinearShader {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let h = scope.create_local(Elem::UInt);
|
||||
let w = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let c = scope.create_local(u32::as_elem());
|
||||
let h = scope.create_local(u32::as_elem());
|
||||
let w = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -80,22 +81,22 @@ impl InterpolateBilinearShader {
|
|||
|
||||
let factor_float = scope.create_local(input.item);
|
||||
let numerator_float = scope.create_local(input.item);
|
||||
let numerator_int = scope.create_local(Elem::UInt);
|
||||
let numerator_int = scope.create_local(u32::as_elem());
|
||||
let denominator_float = scope.create_local(input.item);
|
||||
let denominator_int = scope.create_local(Elem::UInt);
|
||||
let denominator_int = scope.create_local(u32::as_elem());
|
||||
|
||||
let frac = scope.create_local(input.item);
|
||||
let v0 = scope.create_local(input.item);
|
||||
let v1 = scope.create_local(input.item);
|
||||
let one = scope.create_with_value(1f32, input.item);
|
||||
|
||||
let y0 = scope.create_local(Elem::UInt);
|
||||
let y1 = scope.create_local(Elem::UInt);
|
||||
let y0 = scope.create_local(u32::as_elem());
|
||||
let y1 = scope.create_local(u32::as_elem());
|
||||
let yw = scope.create_local(input.item);
|
||||
let yw_ = scope.create_local(input.item);
|
||||
|
||||
let x0 = scope.create_local(Elem::UInt);
|
||||
let x1 = scope.create_local(Elem::UInt);
|
||||
let x0 = scope.create_local(u32::as_elem());
|
||||
let x1 = scope.create_local(u32::as_elem());
|
||||
let xw = scope.create_local(input.item);
|
||||
let xw_ = scope.create_local(input.item);
|
||||
|
||||
|
@ -129,13 +130,13 @@ impl InterpolateBilinearShader {
|
|||
cpa!(scope, x0 = cast(v0));
|
||||
cpa!(scope, x1 = cast(v1));
|
||||
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let y0_stride = scope.create_local(Elem::UInt);
|
||||
let y1_stride = scope.create_local(Elem::UInt);
|
||||
let x0_stride = scope.create_local(Elem::UInt);
|
||||
let x1_stride = scope.create_local(Elem::UInt);
|
||||
let index_base = scope.create_local(u32::as_elem());
|
||||
let index_tmp = scope.create_local(u32::as_elem());
|
||||
let index = scope.create_local(u32::as_elem());
|
||||
let y0_stride = scope.create_local(u32::as_elem());
|
||||
let y1_stride = scope.create_local(u32::as_elem());
|
||||
let x0_stride = scope.create_local(u32::as_elem());
|
||||
let x1_stride = scope.create_local(u32::as_elem());
|
||||
let p_a = scope.create_local(input.item);
|
||||
let p_b = scope.create_local(input.item);
|
||||
let p_c = scope.create_local(input.item);
|
||||
|
|
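Editor's note: the bilinear shader mirrors textbook bilinear interpolation: integer corners y0/y1 and x0/x1, fractional weights yw/xw, and a weighted blend of the four neighbouring values. A hedged CPU reference of the same computation (border handling simplified):

    /// Bilinear interpolation of one sample from its four neighbours.
    fn bilinear(input: &[Vec<f32>], y: f32, x: f32) -> f32 {
        let (h, w) = (input.len(), input[0].len());
        let y0 = (y.floor() as usize).min(h - 1);
        let x0 = (x.floor() as usize).min(w - 1);
        let y1 = (y0 + 1).min(h - 1);
        let x1 = (x0 + 1).min(w - 1);
        let yw = y - y0 as f32;
        let xw = x - x0 as f32;
        let top = input[y0][x0] * (1.0 - xw) + input[y0][x1] * xw;
        let bottom = input[y1][x0] * (1.0 - xw) + input[y1][x1] * xw;
        top * (1.0 - yw) + bottom * yw
    }

    fn main() {
        let img = vec![vec![0.0, 1.0], vec![2.0, 3.0]];
        // Halfway between all four corners: the average.
        assert_eq!(bilinear(&img, 0.5, 0.5), 1.5);
    }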
|
@ -1,6 +1,7 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
ir::{Builtin, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -27,23 +28,23 @@ impl<E: JitElement> InterpolateNearestShader<E> {
|
|||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
let elem = E::cube_elem();
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
let input_stride_2 = scope.create_local(Elem::UInt);
|
||||
let input_stride_3 = scope.create_local(Elem::UInt);
|
||||
let input_stride_0 = scope.create_local(u32::as_elem());
|
||||
let input_stride_1 = scope.create_local(u32::as_elem());
|
||||
let input_stride_2 = scope.create_local(u32::as_elem());
|
||||
let input_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let input_shape_2 = scope.create_local(Elem::UInt);
|
||||
let input_shape_3 = scope.create_local(Elem::UInt);
|
||||
let input_shape_2 = scope.create_local(u32::as_elem());
|
||||
let input_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, input_stride_0 = stride(input, 0u32));
|
||||
cpa!(scope, input_stride_1 = stride(input, 1u32));
|
||||
|
@ -63,10 +64,10 @@ impl<E: JitElement> InterpolateNearestShader<E> {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let h = scope.create_local(Elem::UInt);
|
||||
let w = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let c = scope.create_local(u32::as_elem());
|
||||
let h = scope.create_local(u32::as_elem());
|
||||
let w = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -85,8 +86,8 @@ impl<E: JitElement> InterpolateNearestShader<E> {
|
|||
let denominator_float = scope.create_local(elem);
|
||||
let x = scope.create_local(elem);
|
||||
let y = scope.create_local(elem);
|
||||
let xu = scope.create_local(Elem::UInt);
|
||||
let yu = scope.create_local(Elem::UInt);
|
||||
let xu = scope.create_local(u32::as_elem());
|
||||
let yu = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, factor_float = cast(h));
|
||||
cpa!(scope, numerator_float = cast(input_shape_2));
|
||||
|
@ -104,8 +105,8 @@ impl<E: JitElement> InterpolateNearestShader<E> {
|
|||
cpa!(scope, x = floor(x));
|
||||
cpa!(scope, xu = cast(x));
|
||||
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(u32::as_elem());
|
||||
let index_tmp = scope.create_local(u32::as_elem());
|
||||
let val = scope.create_local(output.item);
|
||||
|
||||
cpa!(scope, index = b * input_stride_0);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -26,25 +27,25 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let grad_stride_0 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_1 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_2 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_3 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_0 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_1 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_2 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let grad_shape_0 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_1 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_2 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_3 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_0 = scope.create_local(u32::as_elem());
|
||||
let grad_shape_1 = scope.create_local(u32::as_elem());
|
||||
let grad_shape_2 = scope.create_local(u32::as_elem());
|
||||
let grad_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, grad_stride_0 = stride(grad, 0u32));
|
||||
cpa!(scope, grad_stride_1 = stride(grad, 1u32));
|
||||
|
@ -66,10 +67,10 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let oh = scope.create_local(Elem::UInt);
|
||||
let ow = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let c = scope.create_local(u32::as_elem());
|
||||
let oh = scope.create_local(u32::as_elem());
|
||||
let ow = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -90,11 +91,11 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
|
|||
|
||||
let result = scope.create_local(grad.item);
|
||||
|
||||
let index_grad = scope.create_local(Elem::UInt);
|
||||
let index_grad_0 = scope.create_local(Elem::UInt);
|
||||
let index_grad_1 = scope.create_local(Elem::UInt);
|
||||
let index_grad_2 = scope.create_local(Elem::UInt);
|
||||
let index_grad_3 = scope.create_local(Elem::UInt);
|
||||
let index_grad = scope.create_local(u32::as_elem());
|
||||
let index_grad_0 = scope.create_local(u32::as_elem());
|
||||
let index_grad_1 = scope.create_local(u32::as_elem());
|
||||
let index_grad_2 = scope.create_local(u32::as_elem());
|
||||
let index_grad_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_grad_0 = b * grad_stride_0);
|
||||
cpa!(scope, index_grad_1 = c * grad_stride_1);
|
||||
|
@ -135,7 +136,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
|
|||
let elem = E::cube_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index = input_index * output_size);
|
||||
cpa!(scope, numerator_float = cast(index));
|
||||
|
@ -156,9 +157,9 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
|
|||
let elem = E::cube_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(u32::as_elem());
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
let end_index = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index = input_index + 1u32);
|
||||
cpa!(scope, index *= output_size);
|
||||
|
|
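Editor's note: the nearest-backward shader computes, for each input-gradient position, the start and end of the output-gradient range that nearest-sampled from it, using products like `input_index * output_size` before dividing by the input size. A hedged sketch of that range computation under the usual floor-based nearest mapping (the exact rounding in the kernel may differ):

    /// Output-gradient rows [start, end) that map to input row `i` under
    /// nearest interpolation with floor(o * input_size / output_size).
    fn contributing_range(i: usize, input_size: usize, output_size: usize) -> (usize, usize) {
        let start = (i * output_size + input_size - 1) / input_size; // ceil
        let end = (((i + 1) * output_size + input_size - 1) / input_size).min(output_size);
        (start, end)
    }

    fn main() {
        // Upsampling 2 -> 5 by nearest: input row 0 feeds output rows 0..3,
        // input row 1 feeds output rows 3..5.
        assert_eq!(contributing_range(0, 2, 5), (0, 3));
        assert_eq!(contributing_range(1, 2, 5), (3, 5));
    }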
|
@@ -23,7 +23,7 @@ pub struct MatmulAutotuneOperationSet<R: JitRuntime, E: FloatElement> {
impl<R: JitRuntime, E: FloatElement> MatmulAutotuneOperationSet<R, E> {
    fn new(lhs: JitTensor<R, E>, rhs: JitTensor<R, E>, out: JitTensor<R, E>) -> Self {
        Self {
            key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)),
            key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())),
            lhs,
            rhs,
            out,
@@ -1,5 +1,5 @@
use crate::tune::anchor;
use burn_tensor::Shape;
use burn_tensor::{DType, Shape};
use core::fmt::Debug;
use serde::{Deserialize, Serialize};
use std::{cmp::max, fmt::Display, hash::Hash};

@@ -13,19 +13,21 @@ pub struct MatmulAutotuneKey {
    anchored_k: usize,
    anchored_n: usize,
    anchored_batch: usize,
    dtype: DType,
}

impl Display for MatmulAutotuneKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(
            format!(
                "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}",
                "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?} dtype:{:?}",
                self.round,
                self.broadcast,
                self.anchored_m,
                self.anchored_k,
                self.anchored_n,
                self.anchored_batch
                self.anchored_batch,
                self.dtype
            )
            .as_str(),
        )

@@ -34,7 +36,7 @@ impl Display for MatmulAutotuneKey {

impl MatmulAutotuneKey {
    /// Create a matmul autotune key from the input shapes
    pub fn new(lhs_shape: &Shape, rhs_shape: &Shape) -> Self {
    pub fn new(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
        let ndims = lhs_shape.num_dims();
        let m = lhs_shape.dims[ndims - 2];
        let k = lhs_shape.dims[ndims - 1];

@@ -62,6 +64,7 @@ impl MatmulAutotuneKey {
            anchored_k: anchor(k, None),
            anchored_n: anchor(n, None),
            anchored_batch: anchor(batch_product, Some(256)),
            dtype,
        }
    }
}

@@ -74,7 +77,7 @@ mod tests {
    fn matmul_autotune_key_all_same_and_round() {
        let lhs_shape: Shape = [4, 512, 512].into();
        let rhs_shape: Shape = [4, 512, 512].into();
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);

        assert!(key.round);
        assert!(!key.broadcast);

@@ -87,7 +90,7 @@ mod tests {
    fn matmul_autotune_key_all_different() {
        let lhs_shape: Shape = [2, 3, 511, 512].into();
        let rhs_shape: Shape = [3, 2, 512, 513].into();
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);

        assert!(!key.round);
        assert!(key.broadcast);

@@ -101,7 +104,7 @@ mod tests {
    fn matmul_autotune_key_large_batch() {
        let lhs_shape: Shape = [128, 512, 511, 512].into();
        let rhs_shape: Shape = [200, 400, 512, 513].into();
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);

        assert!(key.anchored_batch == 256);
    }
|
|
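Threading E::dtype() into the key matters because autotune results are cached per key: a shape tuned in one element type must not hand its winning kernel to another. Below is a minimal, self-contained sketch of that idea; the key type, dtype enum, and kernel labels are illustrative stand-ins, not burn's actual types.

use std::collections::HashMap;

#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy)]
enum MiniDType { F16, F32 }

#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy)]
struct MiniMatmulKey {
    m: usize,
    k: usize,
    n: usize,
    dtype: MiniDType,
}

fn main() {
    let mut fastest: HashMap<MiniMatmulKey, &'static str> = HashMap::new();

    // Same anchored shape, different element types: without `dtype` in the key the
    // two entries below would collide, and the f16 run would silently reuse the
    // kernel selected while tuning in f32. The kernel names are arbitrary labels.
    fastest.insert(MiniMatmulKey { m: 512, k: 512, n: 512, dtype: MiniDType::F32 }, "tiling_2d");
    fastest.insert(MiniMatmulKey { m: 512, k: 512, n: 512, dtype: MiniDType::F16 }, "cmma");

    assert_eq!(fastest.len(), 2);
}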
@ -10,6 +10,7 @@ use cubecl::{
|
|||
ir::{
|
||||
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
|
||||
},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -36,23 +37,23 @@ impl AvgPool2dBackwardComputeShader {
|
|||
let output = self.output;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let grad_stride_0 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_1 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_2 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_3 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_0 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_1 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_2 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let grad_shape_2 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_3 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_2 = scope.create_local(u32::as_elem());
|
||||
let grad_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, grad_stride_0 = stride(grad, 0u32));
|
||||
cpa!(scope, grad_stride_1 = stride(grad, 1u32));
|
||||
|
@ -72,16 +73,16 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
|
||||
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
|
||||
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
|
||||
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let ih = scope.create_local(Elem::UInt);
|
||||
let iw = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let c = scope.create_local(u32::as_elem());
|
||||
let ih = scope.create_local(u32::as_elem());
|
||||
let iw = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -95,16 +96,16 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, iw = id / output_stride_3);
|
||||
cpa!(scope, iw = iw % output_shape_3);
|
||||
|
||||
let index_current = scope.create_local(Elem::UInt);
|
||||
let index_current_tmp = scope.create_local(Elem::UInt);
|
||||
let index_current = scope.create_local(u32::as_elem());
|
||||
let index_current_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_current = ih * output_stride_2);
|
||||
cpa!(scope, index_current_tmp = iw * output_stride_3);
|
||||
cpa!(scope, index_current += index_current_tmp);
|
||||
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(u32::as_elem());
|
||||
let index_tmp = scope.create_local(u32::as_elem());
|
||||
let index_base = scope.create_local(u32::as_elem());
|
||||
|
||||
let grad_accumulation = scope.zero(grad.item);
|
||||
let result = scope.create_local(grad.item);
|
||||
|
@ -130,14 +131,14 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, index_tmp = c * grad_stride_1);
|
||||
cpa!(scope, index_base += index_tmp);
|
||||
|
||||
let border_bottom = scope.create_local(Elem::UInt);
|
||||
let border_right = scope.create_local(Elem::UInt);
|
||||
let begin_h = scope.create_local(Elem::UInt);
|
||||
let begin_w = scope.create_local(Elem::UInt);
|
||||
let iw_start = scope.create_local(Elem::UInt);
|
||||
let iw_end = scope.create_local(Elem::UInt);
|
||||
let ih_start = scope.create_local(Elem::UInt);
|
||||
let ih_end = scope.create_local(Elem::UInt);
|
||||
let border_bottom = scope.create_local(u32::as_elem());
|
||||
let border_right = scope.create_local(u32::as_elem());
|
||||
let begin_h = scope.create_local(u32::as_elem());
|
||||
let begin_w = scope.create_local(u32::as_elem());
|
||||
let iw_start = scope.create_local(u32::as_elem());
|
||||
let iw_end = scope.create_local(u32::as_elem());
|
||||
let ih_start = scope.create_local(u32::as_elem());
|
||||
let ih_end = scope.create_local(u32::as_elem());
|
||||
let after_start = scope.create_local(Elem::Bool);
|
||||
let before_end = scope.create_local(Elem::Bool);
|
||||
let contributed_h = scope.create_local(Elem::Bool);
|
||||
|
@ -147,9 +148,9 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, begin_h = ih + padding_0);
|
||||
cpa!(scope, begin_w = iw + padding_1);
|
||||
|
||||
let ih_diff = scope.create_local(Elem::UInt);
|
||||
let iw_diff = scope.create_local(Elem::UInt);
|
||||
let count_int = scope.create_local(Elem::UInt);
|
||||
let ih_diff = scope.create_local(u32::as_elem());
|
||||
let iw_diff = scope.create_local(u32::as_elem());
|
||||
let count_int = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(
|
||||
scope,
|
||||
|
@ -216,12 +217,12 @@ impl AvgPool2dBackwardComputeShader {
|
|||
output_stride_2: Variable,
|
||||
output_stride_3: Variable,
|
||||
) -> (Variable, Variable, Variable, Variable) {
|
||||
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
|
||||
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
|
||||
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
|
||||
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
|
||||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
|
@ -273,8 +274,8 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
|
||||
cpa!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
|
||||
|
||||
let oh_start = scope.create_local(Elem::UInt);
|
||||
let ow_start = scope.create_local(Elem::UInt);
|
||||
let oh_start = scope.create_local(u32::as_elem());
|
||||
let ow_start = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, oh_start = cast(oh_start_tmp));
|
||||
cpa!(scope, ow_start = cast(ow_start_tmp));
|
||||
|
@ -285,11 +286,11 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, oh_end_tmp = max(kms_0, 0i32));
|
||||
cpa!(scope, ow_end_tmp = max(kms_1, 0i32));
|
||||
|
||||
let oh_end = scope.create_local(Elem::UInt);
|
||||
let ow_end = scope.create_local(Elem::UInt);
|
||||
let oh_end = scope.create_local(u32::as_elem());
|
||||
let ow_end = scope.create_local(u32::as_elem());
|
||||
|
||||
let oh_end_limit = scope.create_local(Elem::UInt);
|
||||
let ow_end_limit = scope.create_local(Elem::UInt);
|
||||
let oh_end_limit = scope.create_local(u32::as_elem());
|
||||
let ow_end_limit = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, oh_end = cast(oh_end_tmp));
|
||||
cpa!(scope, ow_end = cast(ow_end_tmp));
|
||||
|
@ -303,8 +304,8 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, oh_end = min(oh_end, oh_end_limit));
|
||||
cpa!(scope, ow_end = min(ow_end, ow_end_limit));
|
||||
|
||||
let index_current = scope.create_local(Elem::UInt);
|
||||
let index_current_tmp = scope.create_local(Elem::UInt);
|
||||
let index_current = scope.create_local(u32::as_elem());
|
||||
let index_current_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_current = ih * output_stride_2);
|
||||
cpa!(scope, index_current_tmp = iw * output_stride_3);
|
||||
|
@ -340,7 +341,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
|
|
@ -10,6 +10,7 @@ use cubecl::{
|
|||
ir::{
|
||||
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
|
||||
},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -36,23 +37,23 @@ impl MaxPool2dBackwardComputeShader {
|
|||
let indices = self.indices;
|
||||
let id = Variable::builtin(Builtin::AbsolutePos);
|
||||
|
||||
let grad_stride_0 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_1 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_2 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_3 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_0 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_1 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_2 = scope.create_local(u32::as_elem());
|
||||
let grad_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let grad_shape_2 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_3 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_2 = scope.create_local(u32::as_elem());
|
||||
let grad_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
let output_stride_0 = scope.create_local(u32::as_elem());
|
||||
let output_stride_1 = scope.create_local(u32::as_elem());
|
||||
let output_stride_2 = scope.create_local(u32::as_elem());
|
||||
let output_stride_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
let output_shape_0 = scope.create_local(u32::as_elem());
|
||||
let output_shape_1 = scope.create_local(u32::as_elem());
|
||||
let output_shape_2 = scope.create_local(u32::as_elem());
|
||||
let output_shape_3 = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, grad_stride_0 = stride(grad, 0u32));
|
||||
cpa!(scope, grad_stride_1 = stride(grad, 1u32));
|
||||
|
@ -72,10 +73,10 @@ impl MaxPool2dBackwardComputeShader {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let ih = scope.create_local(Elem::UInt);
|
||||
let iw = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
let c = scope.create_local(u32::as_elem());
|
||||
let ih = scope.create_local(u32::as_elem());
|
||||
let iw = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, b = id / output_stride_0);
|
||||
cpa!(scope, b = b % output_shape_0);
|
||||
|
@ -89,8 +90,8 @@ impl MaxPool2dBackwardComputeShader {
|
|||
cpa!(scope, iw = id / output_stride_3);
|
||||
cpa!(scope, iw = iw % output_shape_3);
|
||||
|
||||
let index_current = scope.create_local(Elem::UInt);
|
||||
let index_current_tmp = scope.create_local(Elem::UInt);
|
||||
let index_current = scope.create_local(u32::as_elem());
|
||||
let index_current_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_current = ih * output_stride_2);
|
||||
cpa!(scope, index_current_tmp = iw * output_stride_3);
|
||||
|
@ -98,12 +99,12 @@ impl MaxPool2dBackwardComputeShader {
|
|||
|
||||
let index_select = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
let index_max = scope.create_local(Elem::UInt);
|
||||
let index_max = scope.create_local(u32::as_elem());
|
||||
let is_max = scope.create_local(Elem::Bool);
|
||||
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(u32::as_elem());
|
||||
let index_base = scope.create_local(u32::as_elem());
|
||||
let index_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
let grad_accumulation = scope.zero(grad.item);
|
||||
let result = scope.create_local(grad.item);
|
||||
|
@ -162,12 +163,12 @@ impl MaxPool2dBackwardComputeShader {
|
|||
output_stride_2: Variable,
|
||||
output_stride_3: Variable,
|
||||
) -> (Variable, Variable, Variable, Variable) {
|
||||
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
|
||||
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
|
||||
let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
|
||||
let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
|
||||
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
|
||||
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
|
||||
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(u32::as_elem()));
|
||||
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(u32::as_elem()));
|
||||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
|
@ -219,8 +220,8 @@ impl MaxPool2dBackwardComputeShader {
|
|||
cpa!(scope, oh_start_tmp = max(oh_start_tmp, 0i32));
|
||||
cpa!(scope, ow_start_tmp = max(ow_start_tmp, 0i32));
|
||||
|
||||
let oh_start = scope.create_local(Elem::UInt);
|
||||
let ow_start = scope.create_local(Elem::UInt);
|
||||
let oh_start = scope.create_local(u32::as_elem());
|
||||
let ow_start = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, oh_start = cast(oh_start_tmp));
|
||||
cpa!(scope, ow_start = cast(ow_start_tmp));
|
||||
|
@ -231,11 +232,11 @@ impl MaxPool2dBackwardComputeShader {
|
|||
cpa!(scope, oh_end_tmp = max(kms_0, 0i32));
|
||||
cpa!(scope, ow_end_tmp = max(kms_1, 0i32));
|
||||
|
||||
let oh_end = scope.create_local(Elem::UInt);
|
||||
let ow_end = scope.create_local(Elem::UInt);
|
||||
let oh_end = scope.create_local(u32::as_elem());
|
||||
let ow_end = scope.create_local(u32::as_elem());
|
||||
|
||||
let oh_end_limit = scope.create_local(Elem::UInt);
|
||||
let ow_end_limit = scope.create_local(Elem::UInt);
|
||||
let oh_end_limit = scope.create_local(u32::as_elem());
|
||||
let ow_end_limit = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, oh_end = cast(oh_end_tmp));
|
||||
cpa!(scope, ow_end = cast(ow_end_tmp));
|
||||
|
@ -249,8 +250,8 @@ impl MaxPool2dBackwardComputeShader {
|
|||
cpa!(scope, oh_end = min(oh_end, oh_end_limit));
|
||||
cpa!(scope, ow_end = min(ow_end, ow_end_limit));
|
||||
|
||||
let index_current = scope.create_local(Elem::UInt);
|
||||
let index_current_tmp = scope.create_local(Elem::UInt);
|
||||
let index_current = scope.create_local(u32::as_elem());
|
||||
let index_current_tmp = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, index_current = ih * output_stride_2);
|
||||
cpa!(scope, index_current_tmp = iw * output_stride_3);
|
||||
|
@ -295,7 +296,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
|
|||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Builtin, Elem, Item, Scope, Variable, VariableKind},
|
||||
ir::{Builtin, Item, Scope, Variable, VariableKind},
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, OutputInfo,
|
||||
};
|
||||
|
@ -60,10 +60,10 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
|
|||
|
||||
let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
|
||||
|
||||
let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
|
||||
let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
|
||||
let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
|
||||
let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
|
||||
let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(u32::as_elem()));
|
||||
let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(u32::as_elem()));
|
||||
let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(u32::as_elem()));
|
||||
let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(u32::as_elem()));
|
||||
let seeds = [seed0, seed1, seed2, seed3];
|
||||
|
||||
let mut args = Vec::<Variable>::new();
|
||||
|
@ -80,7 +80,7 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
|
|||
size: P::args_length(),
|
||||
};
|
||||
let seeds = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
elem: u32::as_elem(),
|
||||
size: 4,
|
||||
};
|
||||
let out = OutputInfo::Array { item };
|
||||
|
@ -166,40 +166,40 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
|
|||
let cube_count_y = Variable::builtin(Builtin::CubeCountY);
|
||||
let local_index = Variable::builtin(Builtin::UnitPos);
|
||||
|
||||
let n_invocations = scope.create_local(Elem::UInt);
|
||||
let n_invocations = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, n_invocations = cube_dim_x);
|
||||
cpa!(scope, n_invocations *= cube_dim_y);
|
||||
|
||||
let cube_offset = scope.create_local(Elem::UInt);
|
||||
let cube_offset = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
|
||||
cpa!(scope, cube_offset += cube_pos_y);
|
||||
cpa!(scope, cube_offset *= n_invocations);
|
||||
|
||||
let write_index_base = scope.create_local(Elem::UInt);
|
||||
let write_index_base = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, write_index_base = cube_offset);
|
||||
cpa!(scope, write_index_base *= n_values_per_thread);
|
||||
cpa!(scope, write_index_base += local_index);
|
||||
|
||||
// Set state with unique seeds
|
||||
let thread_seed = scope.create_local(Elem::UInt);
|
||||
let thread_seed = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, thread_seed = cast(1000000007));
|
||||
let thread_seed_index = scope.create_local(Elem::UInt);
|
||||
let thread_seed_index = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, thread_seed_index = cube_offset + local_index);
|
||||
cpa!(scope, thread_seed *= thread_seed_index);
|
||||
|
||||
let state_0 = scope.create_local(Elem::UInt);
|
||||
let state_0 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, state_0 = thread_seed);
|
||||
cpa!(scope, state_0 += seed_0);
|
||||
|
||||
let state_1 = scope.create_local(Elem::UInt);
|
||||
let state_1 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, state_1 = thread_seed);
|
||||
cpa!(scope, state_1 += seed_1);
|
||||
|
||||
let state_2 = scope.create_local(Elem::UInt);
|
||||
let state_2 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, state_2 = thread_seed);
|
||||
cpa!(scope, state_2 += seed_2);
|
||||
|
||||
let state_3 = scope.create_local(Elem::UInt);
|
||||
let state_3 = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, state_3 = thread_seed);
|
||||
cpa!(scope, state_3 += seed_3);
|
||||
|
||||
|
@ -260,7 +260,7 @@ fn taus_step(
|
|||
s3: Variable,
|
||||
m: Variable,
|
||||
) {
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let b = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, b = z << s1);
|
||||
cpa!(scope, b = b ^ z);
|
||||
cpa!(scope, b = b >> s2);
|
||||
|
|
|
@ -2,6 +2,7 @@ use burn_tensor::Shape;
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Elem, FloatKind, Scope, Variable},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
@ -34,7 +35,14 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
|
|||
output: Variable,
|
||||
) {
|
||||
let float_elem = Elem::Float(FloatKind::F32);
|
||||
let prob = args[0];
|
||||
let mut prob = args[0];
|
||||
|
||||
if prob.item.elem() != Elem::Float(FloatKind::F32) {
|
||||
let prob_f32 = scope.create_local(float_elem);
|
||||
cpa!(scope, prob_f32 = cast(prob));
|
||||
prob = prob_f32;
|
||||
}
|
||||
|
||||
cpa!(
|
||||
scope,
|
||||
range(0u32, n_values_per_thread).for_each(|i, scope| {
|
||||
|
@ -43,7 +51,7 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
|
|||
taus_step_2(scope, state_2);
|
||||
lcg_step(scope, state_3);
|
||||
|
||||
let int_random = scope.create_local(Elem::UInt);
|
||||
let int_random = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, int_random = state_0 ^ state_1);
|
||||
cpa!(scope, int_random = int_random ^ state_2);
|
||||
cpa!(scope, int_random = int_random ^ state_3);
|
||||
|
@ -54,7 +62,7 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
|
|||
let bernoulli = scope.create_local(Elem::Bool);
|
||||
cpa!(scope, bernoulli = float_random < prob);
|
||||
|
||||
let write_index = scope.create_local(Elem::UInt);
|
||||
let write_index = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, write_index = i * n_invocations);
|
||||
cpa!(scope, write_index += write_index_base);
|
||||
cpa!(scope, output[write_index] = bernoulli);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Elem, FloatKind, Scope, Variable},
|
||||
prelude::*,
|
||||
};
|
||||
use std::f32::consts::PI;
|
||||
|
||||
|
@ -47,7 +48,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
|
|||
cpa!(
|
||||
scope,
|
||||
range(0u32, n_values_per_thread / 2).for_each(|i, scope| {
|
||||
let int_random = scope.create_local(Elem::UInt);
|
||||
let int_random = scope.create_local(u32::as_elem());
|
||||
|
||||
// First random uniform integer
|
||||
taus_step_0(scope, state_0);
|
||||
|
@ -95,9 +96,9 @@ impl<E: JitElement> Prng<E> for Normal<E> {
|
|||
cpa!(scope, normal_1 += mean);
|
||||
|
||||
// Write to output
|
||||
let write_index_0 = scope.create_local(Elem::UInt);
|
||||
let write_index_1 = scope.create_local(Elem::UInt);
|
||||
let iteration_offset = scope.create_local(Elem::UInt);
|
||||
let write_index_0 = scope.create_local(u32::as_elem());
|
||||
let write_index_1 = scope.create_local(u32::as_elem());
|
||||
let iteration_offset = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, write_index_0 = write_index_base);
|
||||
cpa!(scope, iteration_offset = two * i);
|
||||
cpa!(scope, iteration_offset *= n_invocations);
|
||||
|
|
|
@ -2,6 +2,7 @@ use burn_tensor::Shape;
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Elem, FloatKind, Scope, Variable},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
@ -49,7 +50,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
|
|||
taus_step_2(scope, state_2);
|
||||
lcg_step(scope, state_3);
|
||||
|
||||
let int_random = scope.create_local(Elem::UInt);
|
||||
let int_random = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, int_random = state_0 ^ state_1);
|
||||
cpa!(scope, int_random = int_random ^ state_2);
|
||||
cpa!(scope, int_random = int_random ^ state_3);
|
||||
|
@ -65,7 +66,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
|
|||
cpa!(scope, uniform = cast(uniform_float));
|
||||
cpa!(scope, uniform += lower_bound);
|
||||
|
||||
let write_index = scope.create_local(Elem::UInt);
|
||||
let write_index = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, write_index = i * n_invocations);
|
||||
cpa!(scope, write_index += write_index_base);
|
||||
cpa!(scope, output[write_index] = uniform);
|
||||
|
|
|
@ -3,6 +3,7 @@ use burn_tensor::cast::ToElement;
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Elem, Item, Scope, Variable},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use super::base::ReduceDimShared;
|
||||
|
@ -17,7 +18,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
|
|||
input_item: Item,
|
||||
) -> Self::Accumulator {
|
||||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(u32::as_elem(), shared_memory_size);
|
||||
let max = input_item
|
||||
.elem()
|
||||
.constant_from_f64(ToElement::to_f64(&E::minimum_value()));
|
||||
|
|
|
@ -2,6 +2,7 @@ use burn_tensor::cast::ToElement;
|
|||
use cubecl::{
|
||||
cpa,
|
||||
ir::{Elem, Item, Scope, Variable},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{kernel::reduce::Argmin, JitElement};
|
||||
|
@ -18,7 +19,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
|
|||
input_item: Item,
|
||||
) -> Self::Accumulator {
|
||||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(u32::as_elem(), shared_memory_size);
|
||||
let min = input_item
|
||||
.elem()
|
||||
.constant_from_f64(ToElement::to_f64(&E::maximum_value()));
|
||||
|
|
|
@ -2,6 +2,7 @@ use cubecl::{
|
|||
cpa,
|
||||
ir::{Builtin, KernelDefinition, VariableKind},
|
||||
prelude::CubeCount,
|
||||
prelude::*,
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
|
@ -120,38 +121,38 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
|||
let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
|
||||
let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
|
||||
|
||||
let stride_reduce_dim_input = scope.create_local(Elem::UInt);
|
||||
let stride_reduce_dim_input = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
|
||||
let shape_reduce_dim_input = scope.create_local(Elem::UInt);
|
||||
let shape_reduce_dim_input = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, shape_reduce_dim_input = shape(tensor, dim));
|
||||
|
||||
// To determine which reduce_group (not position, but absolute id)
|
||||
let reduce_group_id = scope.create_local(Elem::UInt);
|
||||
let reduce_group_id = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
|
||||
cpa!(scope, reduce_group_id += cube_pos_x);
|
||||
|
||||
// nth thread in the cube
|
||||
let local_id = scope.create_local(Elem::UInt);
|
||||
let local_id = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
|
||||
cpa!(scope, local_id += local_invocation_id_x);
|
||||
|
||||
let n_threads = scope.create_local(Elem::UInt);
|
||||
let n_threads = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, n_threads = cube_dim_x * cube_dim_y);
|
||||
|
||||
let index_offset = scope.zero(Elem::UInt);
|
||||
let index_offset = scope.zero(u32::as_elem());
|
||||
|
||||
cpa!(
|
||||
scope,
|
||||
range(0u32, rank).for_each(|i, scope| {
|
||||
let stride_input = scope.create_local(Elem::UInt);
|
||||
let stride_output = scope.create_local(Elem::UInt);
|
||||
let shape_output = scope.create_local(Elem::UInt);
|
||||
let stride_input = scope.create_local(u32::as_elem());
|
||||
let stride_output = scope.create_local(u32::as_elem());
|
||||
let shape_output = scope.create_local(u32::as_elem());
|
||||
|
||||
cpa!(scope, stride_input = stride(tensor, i));
|
||||
cpa!(scope, stride_output = stride(output, i));
|
||||
cpa!(scope, shape_output = shape(output, i));
|
||||
|
||||
let num_block = scope.create_local(Elem::UInt);
|
||||
let num_block = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, num_block = reduce_group_id / stride_output);
|
||||
cpa!(scope, num_block = num_block % shape_output);
|
||||
cpa!(scope, num_block = num_block * stride_input);
|
||||
|
@ -166,14 +167,14 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
|||
cpa!(
|
||||
scope,
|
||||
range(0u32, self.n_input_values_per_thread).for_each(|i, scope| {
|
||||
let nth = scope.create_local(Elem::UInt);
|
||||
let nth = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, nth = i * n_threads);
|
||||
cpa!(scope, nth += local_id);
|
||||
|
||||
let within_shape = scope.create_local(Elem::Bool);
|
||||
|
||||
if self.divisible_shape {
|
||||
let current_position = scope.create_local(Elem::UInt);
|
||||
let current_position = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, current_position = nth * stride_reduce_dim_input);
|
||||
cpa!(scope, current_position += index_offset);
|
||||
|
||||
|
@ -182,7 +183,7 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
|||
} else {
|
||||
cpa!(scope, within_shape = nth < shape_reduce_dim_input);
|
||||
cpa!(scope, if(within_shape).then(|scope|{
|
||||
let current_position = scope.create_local(Elem::UInt);
|
||||
let current_position = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, current_position = nth * stride_reduce_dim_input);
|
||||
cpa!(scope, current_position += index_offset);
|
||||
|
||||
|
@ -208,7 +209,7 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
|||
let updating_thread = scope.create_local(Elem::Bool);
|
||||
cpa!(scope, updating_thread = local_id < n_threads);
|
||||
cpa!(scope, if(updating_thread).then(|scope|{
|
||||
let read_position = scope.create_local(Elem::UInt);
|
||||
let read_position = scope.create_local(u32::as_elem());
|
||||
cpa!(scope, read_position = n_threads + local_id);
|
||||
|
||||
let read_value = RD::read_from_shared(scope, shared_memory, read_position);
|
||||
|
|
|
@@ -44,6 +44,7 @@ impl<RD: ReduceDimAlgorithm<EI>, R: JitRuntime, EI: JitElement, EO: JitElement>
&input.shape,
&input.strides,
reduce_dim,
EI::dtype(),
)),
input,
output,

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::{cmp::min, fmt::Display};

use burn_tensor::Shape;
use burn_tensor::{DType, Shape};

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
/// Autotune key representative of reduce versions

@@ -9,14 +9,15 @@ pub struct ReduceAutotuneKey {
pub(crate) reduce_dim_length: usize,
pub(crate) reduce_dim_stride: usize,
pub(crate) others_product: usize,
dtype: DType,
}

impl Display for ReduceAutotuneKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
format!(
"Reduce - reduce_dim_length: {:?} reduce_dim_stride: {:?} others_product: {:?}",
self.reduce_dim_length, self.reduce_dim_stride, self.others_product
"Reduce - reduce_dim_length: {:?} reduce_dim_stride: {:?} others_product: {:?} dtype: {:?}",
self.reduce_dim_length, self.reduce_dim_stride, self.others_product, self.dtype
)
.as_str(),
)

@@ -25,7 +26,7 @@ impl Display for ReduceAutotuneKey {
impl ReduceAutotuneKey {
/// Create a reduce autotune key from the input shape and reduce dim
pub fn new(shape: &Shape, strides: &[usize], reduce_dim: usize) -> Self {
pub fn new(shape: &Shape, strides: &[usize], reduce_dim: usize, dtype: DType) -> Self {
let ndims = strides.len();
let reduce_dim_length = shape.dims[reduce_dim];
let reduce_dim_stride = strides[reduce_dim];

@@ -39,6 +40,7 @@ impl ReduceAutotuneKey {
reduce_dim_length: anchor(reduce_dim_length, None),
reduce_dim_stride: anchor(reduce_dim_stride, None),
others_product: anchor(others_product, None),
dtype,
}
}
}
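Both autotune keys also bucket their dimensions through the anchor helper so that nearby shapes share one cache entry. The reimplementation below is a guess at its behaviour (next power of two, optionally capped), for illustration only; the real rule lives in crate::tune::anchor.

fn anchor(value: usize, cap: Option<usize>) -> usize {
    // Round up to the next power of two so 511 and 512 land in the same bucket,
    // then optionally clamp (the matmul key caps the batch dimension at 256).
    let anchored = value.next_power_of_two();
    match cap {
        Some(cap) => anchored.min(cap),
        None => anchored,
    }
}

fn main() {
    assert_eq!(anchor(511, None), 512);
    assert_eq!(anchor(512, None), 512);
    assert_eq!(anchor(1000, Some(256)), 256);
}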
@@ -21,6 +21,14 @@ pub(crate) async fn into_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R,
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}

#[allow(unused, reason = "useful for debugging kernels")]
pub(crate) fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R, E>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);

let bytes = tensor.client.read(tensor.handle.binding());
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}

pub(crate) async fn bool_into_data<R: JitRuntime>(tensor: JitTensor<R, u32>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read_async(tensor.handle.binding()).await;
@@ -37,8 +37,12 @@ pub use serial_test;
#[macro_export]
macro_rules! testgen_all {
() => {
use burn_tensor::{Float, Int};
$crate::testgen_all!([Float], [Int]);
};
([$($float:ident),*], [$($int:ident),*]) => {
mod jit {
burn_jit::testgen_jit!();
burn_jit::testgen_jit!([$($float),*], [$($int),*]);

mod kernel {
use super::*;

@@ -79,7 +83,7 @@ macro_rules! testgen_all {
}
}
mod jit_fusion {
burn_jit::testgen_jit_fusion!();
burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*]);
}
};
}
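The testgen macros now take optional lists of float and int types while keeping the old zero-argument form as a default. The toy macro below demonstrates only the delegation pattern used above (a no-argument arm forwarding to a bracketed, repeatable arm); the names are placeholders, not the real burn test suite.

macro_rules! testgen_demo {
    () => {
        // The no-argument form keeps the old default element types.
        testgen_demo!([f32], [i32]);
    };
    ([$($float:ty),*], [$($int:ty),*]) => {
        fn describe() {
            $(println!("float suite for {}", core::any::type_name::<$float>());)*
            $(println!("int suite for {}", core::any::type_name::<$int>());)*
        }
    };
}

testgen_demo!();

fn main() {
    describe();
}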
@ -87,22 +91,32 @@ macro_rules! testgen_all {
|
|||
#[macro_export]
|
||||
macro_rules! testgen_jit {
|
||||
() => {
|
||||
use super::*;
|
||||
use burn_tensor::{Float, Int};
|
||||
$crate::testgen_jit!([Float], [Int]);
|
||||
};
|
||||
([$($float:ident),*], [$($int:ident),*]) => {
|
||||
pub use super::*;
|
||||
use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test};
|
||||
|
||||
pub type TestBackend = JitBackend<TestRuntime, f32, i32>;
|
||||
pub type TestBackend2<F, I> = JitBackend<TestRuntime, F, I>;
|
||||
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
pub type TestTensor2<F, I, const D: usize> = burn_tensor::Tensor<TestBackend2<F, I>, D>;
|
||||
pub type TestTensorInt<const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
|
||||
pub type TestTensorInt2<F, I, const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Int>;
|
||||
pub type TestTensorBool<const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;
|
||||
pub type TestTensorBool2<F, I, const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Bool>;
|
||||
|
||||
pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
|
||||
|
||||
burn_tensor::testgen_all!();
|
||||
burn_autodiff::testgen_all!();
|
||||
burn_tensor::testgen_all!([$($float),*], [$($int),*]);
|
||||
burn_autodiff::testgen_all!([$($float),*]);
|
||||
|
||||
// Not all ops are implemented for quantization yet, notably missing:
|
||||
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
|
||||
|
@ -111,28 +125,38 @@ macro_rules! testgen_jit {
|
|||
burn_tensor::testgen_calibration!();
|
||||
burn_tensor::testgen_scheme!();
|
||||
burn_tensor::testgen_quantize!();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! testgen_jit_fusion {
|
||||
() => {
|
||||
use burn_tensor::{Float, Int};
|
||||
$crate::testgen_jit_fusion!([Float], [Int]);
|
||||
};
|
||||
([$($float:ident),*], [$($int:ident),*]) => {
|
||||
use super::*;
|
||||
use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor};
|
||||
|
||||
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime, f32, i32>>;
|
||||
pub type TestBackend2<F, I> = burn_fusion::Fusion<JitBackend<TestRuntime, F, I>>;
|
||||
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
pub type TestTensor2<F, I, const D: usize> = burn_tensor::Tensor<TestBackend2<F, I>, D>;
|
||||
pub type TestTensorInt<const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
|
||||
pub type TestTensorInt2<F, I, const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Int>;
|
||||
pub type TestTensorBool<const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;
|
||||
pub type TestTensorBool2<F, I, const D: usize> =
|
||||
burn_tensor::Tensor<TestBackend2<F, I>, D, burn_tensor::Bool>;
|
||||
|
||||
pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
|
||||
|
||||
burn_tensor::testgen_all!();
|
||||
burn_autodiff::testgen_all!();
|
||||
burn_tensor::testgen_all!([$($float),*], [$($int),*]);
|
||||
burn_autodiff::testgen_all!([$($float),*]);
|
||||
|
||||
// Not all ops are implemented for quantization yet, notably missing:
|
||||
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
|
||||
|
|
|
@ -2,20 +2,25 @@
|
|||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Tensor library with user-friendly APIs and automatic differentiation support"
|
||||
documentation = "https://docs.rs/burn-tensor"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-tensor"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor"
|
||||
documentation = "https://docs.rs/burn-tensor"
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
cubecl = ["dep:cubecl"]
|
||||
cubecl-cuda = ["cubecl", "cubecl/cuda"]
|
||||
cubecl-hip = ["cubecl", "cubecl/hip"]
|
||||
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
|
||||
default = ["std", "repr"]
|
||||
doc = ["default"]
|
||||
experimental-named-tensor = []
|
||||
export_tests = ["burn-tensor-testgen"]
|
||||
export_tests = ["burn-tensor-testgen", "cubecl"]
|
||||
repr = []
|
||||
std = [
|
||||
"rand/std",
|
||||
"half/std",
|
||||
|
@ -24,24 +29,19 @@ std = [
|
|||
"burn-common/rayon",
|
||||
"colored",
|
||||
]
|
||||
repr = []
|
||||
cubecl = ["dep:cubecl"]
|
||||
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
|
||||
cubecl-cuda = ["cubecl", "cubecl/cuda"]
|
||||
cubecl-hip = ["cubecl", "cubecl/hip"]
|
||||
|
||||
[dependencies]
|
||||
burn-common = { path = "../burn-common", version = "0.16.0", default-features = false }
|
||||
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true }
|
||||
cubecl = { workspace = true, optional = true }
|
||||
cubecl = { workspace = true, optional = true, default-features = true }
|
||||
|
||||
bytemuck = { workspace = true }
|
||||
colored = { workspace = true, optional = true }
|
||||
derive-new = { workspace = true }
|
||||
half = { workspace = true, features = ["bytemuck"] }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
|
||||
bytemuck = { workspace = true }
|
||||
colored = { workspace = true, optional = true }
|
||||
|
||||
# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
|
||||
hashbrown = { workspace = true } # no_std compatible
|
||||
|
@ -50,6 +50,7 @@ hashbrown = { workspace = true } # no_std compatible
|
|||
serde = { workspace = true }
|
||||
serde_bytes = { workspace = true }
|
||||
|
||||
|
||||
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
|
||||
portable-atomic-util = { workspace = true }
|
||||
|
||||
|
|
|
@@ -20,7 +20,7 @@ pub mod repr;

#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
mod tests;
pub mod tests;

pub use half::{bf16, f16};
pub(crate) use tensor::check::macros::check;

@@ -30,7 +30,7 @@ pub use burn_common::reader::*; // Useful so that backends don't have to add `bu

#[cfg(feature = "cubecl")]
mod cube {
use cubecl::ir::{Elem, FloatKind, IntKind};
use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind};

impl From<crate::DType> for cubecl::ir::Elem {
fn from(dtype: crate::DType) -> Self {

@@ -41,11 +41,12 @@ mod cube {
crate::DType::BF16 => Elem::Float(FloatKind::BF16),
crate::DType::I64 => Elem::Int(IntKind::I64),
crate::DType::I32 => Elem::Int(IntKind::I32),
crate::DType::I16 => panic!("i16 isn't supported yet."),
crate::DType::I8 => panic!("i8 isn't supported yet."),
crate::DType::U64 => Elem::UInt,
crate::DType::U32 => Elem::UInt,
crate::DType::U8 => panic!("u8 isn't supported yet."),
crate::DType::I16 => Elem::Int(IntKind::I16),
crate::DType::I8 => Elem::Int(IntKind::I8),
crate::DType::U64 => Elem::UInt(UIntKind::U64),
crate::DType::U32 => Elem::UInt(UIntKind::U32),
crate::DType::U16 => Elem::UInt(UIntKind::U16),
crate::DType::U8 => Elem::UInt(UIntKind::U8),
crate::DType::Bool => Elem::Bool,
crate::DType::QFloat(_) => panic!("quantized type is not supported yet."),
}
@@ -1,4 +1,7 @@
use core::any::{Any, TypeId};
use core::{
any::{Any, TypeId},
f32,
};

use alloc::boxed::Box;
use alloc::format;

@@ -181,6 +184,11 @@ impl TensorData {
.map(|e: &i64| e.elem::<E>()),
),
DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
DType::U16 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u16| e.elem::<E>()),
),
DType::U32 => Box::new(
bytemuck::checked::cast_slice(&self.bytes)
.iter()
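The new DType::U16 arm reinterprets the raw byte buffer as u16 values before converting to the requested element type. A standalone miniature of that step, using plain std byte conversion for portability (the real code goes through bytemuck::checked::cast_slice):

fn main() {
    let values: Vec<u16> = vec![1, 2, 300];
    // Serialize to the flat byte buffer a TensorData would hold.
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_ne_bytes()).collect();

    // Reinterpret two bytes at a time as u16, then convert to the target element.
    let roundtrip: Vec<f32> = bytes
        .chunks_exact(core::mem::size_of::<u16>())
        .map(|c| u16::from_ne_bytes([c[0], c[1]]) as f32)
        .collect();

    assert_eq!(roundtrip, vec![1.0, 2.0, 300.0]);
}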
@ -313,6 +321,7 @@ impl TensorData {
|
|||
DType::I8 => self.convert_inplace::<i8, E>(),
|
||||
DType::U64 => self.convert_inplace::<u64, E>(),
|
||||
DType::U32 => self.convert_inplace::<u32, E>(),
|
||||
DType::U16 => self.convert_inplace::<u16, E>(),
|
||||
DType::U8 => self.convert_inplace::<u8, E>(),
|
||||
DType::Bool | DType::QFloat(_) => unreachable!(),
|
||||
}
|
||||
|
@ -419,6 +428,7 @@ impl TensorData {
|
|||
DType::I8 => self.assert_eq_elem::<i8>(other),
|
||||
DType::U64 => self.assert_eq_elem::<u64>(other),
|
||||
DType::U32 => self.assert_eq_elem::<u32>(other),
|
||||
DType::U16 => self.assert_eq_elem::<u16>(other),
|
||||
DType::U8 => self.assert_eq_elem::<u8>(other),
|
||||
DType::Bool => self.assert_eq_elem::<bool>(other),
|
||||
DType::QFloat(q) => {
|
||||
|
@ -511,9 +521,21 @@ impl TensorData {
|
|||
continue;
|
||||
}
|
||||
|
||||
let err = ((a - b).pow(2.0f64)).sqrt();
|
||||
let err = (a - b).abs();
|
||||
|
||||
if err > tolerance || err.is_nan() {
|
||||
if self.dtype.is_float() {
|
||||
if let Some((err, tolerance)) = compare_floats(a, b, self.dtype, tolerance) {
|
||||
// Only print the first 5 different values.
|
||||
if num_diff < max_num_diff {
|
||||
message += format!(
|
||||
"\n => Position {i}: {a} != {b} | difference {err} > tolerance \
|
||||
{tolerance}"
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
num_diff += 1;
|
||||
}
|
||||
} else if err > tolerance || err.is_nan() {
|
||||
// Only print the first 5 different values.
|
||||
if num_diff < max_num_diff {
|
||||
message += format!(
|
||||
|
@ -683,6 +705,7 @@ impl core::fmt::Display for TensorData {
|
|||
DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
|
||||
DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
|
||||
DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
|
||||
DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
|
||||
DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
|
||||
DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
|
||||
DType::QFloat(q) => match &q {
|
||||
|
@ -869,7 +892,7 @@ impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
|
|||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E, D> {
|
||||
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data<E, D> {
|
||||
/// Asserts the data is approximately equal to another data.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -926,9 +949,21 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
|
|||
continue;
|
||||
}
|
||||
|
||||
let err = ((a - b).pow(2.0f64)).sqrt();
|
||||
let err = (a - b).abs();
|
||||
|
||||
if err > tolerance || err.is_nan() {
|
||||
if E::dtype().is_float() {
|
||||
if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) {
|
||||
// Only print the first 5 different values.
|
||||
if num_diff < max_num_diff {
|
||||
message += format!(
|
||||
"\n => Position {i}: {a} != {b} | difference {err} > tolerance \
|
||||
{tolerance}"
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
num_diff += 1;
|
||||
}
|
||||
} else if err > tolerance || err.is_nan() {
|
||||
// Only print the first 5 different values.
|
||||
if num_diff < max_num_diff {
|
||||
message += format!(
|
||||
|
@@ -1076,6 +1111,30 @@ impl<E: core::fmt::Debug, const D: usize> core::fmt::Display for Data<E, D> {
}
}

fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> {
let epsilon_deviations = tolerance / f32::EPSILON as f64;
let epsilon = match ty {
DType::F64 => f32::EPSILON as f64, // Don't increase precision beyond `f32`, see below
DType::F32 => f32::EPSILON as f64,
DType::F16 => half::f16::EPSILON.to_f64(),
DType::BF16 => half::bf16::EPSILON.to_f64(),
_ => unreachable!(),
};
let tolerance_norm = epsilon_deviations * epsilon;
// Clamp to 1.0 so we don't require more precision than `tolerance`. This is because literals
// have a fixed number of digits, so increasing precision breaks things
let value_abs = value.abs().max(1.0);
let tolerance_adjusted = tolerance_norm * value_abs;

let err = (value - other).abs();

if err > tolerance_adjusted || err.is_nan() {
Some((err, tolerance_adjusted))
} else {
None
}
}

#[cfg(test)]
#[allow(deprecated)]
mod tests {

@@ -1140,16 +1199,16 @@ mod tests {
#[test]
fn should_assert_appox_eq_limit() {
let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
let data2 = TensorData::from([[3.01, 5.0, 6.0]]);
let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

data1.assert_approx_eq(&data2, 2);
}

#[test]
#[should_panic]
fn should_assert_appox_eq_above_limit() {
fn should_assert_approx_eq_above_limit() {
let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
let data2 = TensorData::from([[3.011, 5.0, 6.0]]);
let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

data1.assert_approx_eq(&data2, 2);
}
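The dtype-aware comparison above is why many expectations in this commit loosen from precision 4 to 3, and why the 3.01 / 3.011 limits in the tests move to 3.03 / 3.031: the allowed error scales with the element type's epsilon. Below is a standalone restatement of the same arithmetic, assuming assert_approx_eq's usual tolerance of 10^-precision (an assumption; the f16 epsilon literal is 2^-10).

fn tolerance_for(epsilon: f64, tolerance: f64, value: f64) -> f64 {
    // Mirrors compare_floats: express the requested tolerance as a count of f32
    // epsilons, re-express it in the target type's epsilons, then widen the bound
    // for values larger than 1.0.
    let epsilon_deviations = tolerance / f32::EPSILON as f64;
    let tolerance_norm = epsilon_deviations * epsilon;
    tolerance_norm * value.abs().max(1.0)
}

fn main() {
    let requested = 1e-3; // e.g. assert_approx_eq(&expected, 3)
    let f32_eps = f32::EPSILON as f64; // 2^-23
    let f16_eps = 0.0009765625_f64; // 2^-10, i.e. half::f16::EPSILON

    // For f32 data the bound stays at the requested tolerance...
    println!("f32 bound: {}", tolerance_for(f32_eps, requested, 0.7311));
    // ...for f16 data it is scaled by the epsilon ratio 2^13 = 8192, so comparisons
    // against half-precision outputs are far more forgiving.
    println!("f16 bound: {}", tolerance_for(f16_eps, requested, 0.7311));
}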
@@ -1,6 +1,8 @@
use core::cmp::Ordering;

use crate::{cast::ToElement, quantization::QuantizationStrategy, Distribution};
#[cfg(feature = "cubecl")]
use cubecl::flex32;
use half::{bf16, f16};
use rand::RngCore;
use serde::{Deserialize, Serialize};

@@ -193,6 +195,14 @@ make_element!(
dtype DType::I16
);

make_element!(
ty u16 Precision::Half,
convert |elem: &dyn ToElement| elem.to_u16(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u16, b: &u16| Ord::cmp(a, b),
dtype DType::U16
);

make_element!(
ty i8 Precision::Other,
convert |elem: &dyn ToElement| elem.to_i8(),

@@ -230,6 +240,18 @@ make_element!(
dtype DType::BF16
);

#[cfg(feature = "cubecl")]
make_element!(
ty flex32 Precision::Half,
convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
flex32::from_elem(sample)
},
cmp |a: &flex32, b: &flex32| a.total_cmp(b),
dtype DType::F32
);

make_element!(
ty bool Precision::Other,
convert |elem: &dyn ToElement| elem.to_u8() != 0,

@@ -254,6 +276,7 @@ pub enum DType {
I8,
U64,
U32,
U16,
U8,
Bool,
QFloat(QuantizationStrategy),

@@ -273,6 +296,7 @@ impl DType {
DType::I8 => core::mem::size_of::<i8>(),
DType::U64 => core::mem::size_of::<u64>(),
DType::U32 => core::mem::size_of::<u32>(),
DType::U16 => core::mem::size_of::<u16>(),
DType::U8 => core::mem::size_of::<u8>(),
DType::Bool => core::mem::size_of::<bool>(),
DType::QFloat(strategy) => match strategy {
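The make_element! invocations above register u16 and cubecl's flex32 as tensor elements. The sketch below spells out, with hypothetical trait and type names, the kind of surface such a registration provides (conversion, reported dtype, ordering); it is not the real macro expansion.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MiniDType { U16 }

trait MiniElement: Copy {
    // Convert from a wider host value, report the element's dtype, and order values.
    fn from_f64(value: f64) -> Self;
    fn dtype() -> MiniDType;
    fn compare(a: &Self, b: &Self) -> core::cmp::Ordering;
}

impl MiniElement for u16 {
    fn from_f64(value: f64) -> Self {
        value as u16
    }
    fn dtype() -> MiniDType {
        MiniDType::U16
    }
    fn compare(a: &Self, b: &Self) -> core::cmp::Ordering {
        Ord::cmp(a, b)
    }
}

fn main() {
    assert_eq!(u16::from_f64(300.7), 300);
    assert_eq!(u16::dtype(), MiniDType::U16);
}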
@ -442,6 +442,50 @@ impl ToElement for bf16 {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cubecl")]
|
||||
impl ToElement for cubecl::flex32 {
|
||||
#[inline]
|
||||
fn to_i64(&self) -> i64 {
|
||||
Self::to_f32(*self).to_i64()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u64(&self) -> u64 {
|
||||
Self::to_f32(*self).to_u64()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i8(&self) -> i8 {
|
||||
Self::to_f32(*self).to_i8()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u8(&self) -> u8 {
|
||||
Self::to_f32(*self).to_u8()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i16(&self) -> i16 {
|
||||
Self::to_f32(*self).to_i16()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u16(&self) -> u16 {
|
||||
Self::to_f32(*self).to_u16()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i32(&self) -> i32 {
|
||||
Self::to_f32(*self).to_i32()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u32(&self) -> u32 {
|
||||
Self::to_f32(*self).to_u32()
|
||||
}
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
Self::to_f32(*self)
|
||||
}
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
Self::to_f64(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToElement for bool {
|
||||
#[inline]
|
||||
fn to_i64(&self) -> i64 {
|
||||
|
|
|
@@ -15,7 +15,7 @@ mod tests {

#[test]
fn test_hard_sigmoid_overflow() {
let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);
let tensor = TestTensor::<1>::from([FloatType::MAX, FloatType::MIN]);

let output = activation::hard_sigmoid(tensor, 0.2, 0.5);
let expected = TensorData::from([1.0, 0.0]);
@@ -9,9 +9,9 @@ mod tests {

let output = activation::leaky_relu(tensor, 0.01);

output.into_data().assert_eq(
&TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]),
false,
);
// Account for conversion errors if `FloatType != f32`
output
.into_data()
.assert_approx_eq(&TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]), 5);
}
}
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(log_sigmoid)]
mod tests {
use super::*;
use burn_tensor::{activation, Tensor, TensorData};
use burn_tensor::{activation, cast::ToElement, tests::Numeric, Tensor, TensorData};

#[test]
fn test_log_sigmoid() {

@@ -10,7 +10,7 @@ mod tests {
let output = activation::log_sigmoid(tensor);
let expected = TensorData::from([[-3.132617e-1, -9.114665e-4], [-2.260327e-6, -3.0485873]]);

output.into_data().assert_approx_eq(&expected, 4);
output.into_data().assert_approx_eq(&expected, 3);
}

#[test]

@@ -23,9 +23,10 @@ mod tests {
let expected = TensorData::from([0.0, -300.0]);
output.into_data().assert_approx_eq(&expected, 4);

let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);
let tensor =
TestTensor::<1>::from([<FloatType as Numeric>::MAX, <FloatType as Numeric>::MIN]);
let output = activation::log_sigmoid(tensor);
let expected = TensorData::from([0.0, f32::MIN]);
let expected = TensorData::from([0.0, <FloatType as Numeric>::MIN.to_f32()]);

output.into_data().assert_approx_eq(&expected, 4);
}
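Switching the overflow tests from f32::MAX to the backend float type's own extremes matters once the suites can be generated for half precision: f32::MAX does not fit in f16 and would saturate to infinity instead of exercising a finite extreme. A small check of that, assuming the half crate as a dependency:

fn main() {
    // Squeezing f32::MAX into f16 overflows to infinity...
    let as_f16 = half::f16::from_f32(f32::MAX);
    assert!(as_f16.is_infinite());

    // ...while the type's own MAX stays finite, which is what the tests now use.
    assert!(half::f16::MAX.is_finite());
}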
@ -11,6 +11,6 @@ mod tests {
|
|||
let output = activation::mish(tensor);
|
||||
let expected = TensorData::from([[-0.1971, -0.3006, -0.1172], [-0.2413, 0.5823, -0.0888]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(sigmoid)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{activation, Tensor, TensorData};
|
||||
use burn_tensor::{activation, tests::Numeric, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_sigmoid() {
|
||||
|
@ -10,12 +10,12 @@ mod tests {
|
|||
let output = activation::sigmoid(tensor);
|
||||
let expected = TensorData::from([[0.7311, 0.9991], [1.0, 0.0474]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sigmoid_overflow() {
|
||||
let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);
|
||||
let tensor = TestTensor::<1>::from([FloatType::MAX, FloatType::MIN]);
|
||||
|
||||
let output = activation::sigmoid(tensor);
|
||||
let expected = TensorData::from([1.0, 0.0]);
|
||||
|
|
|
@ -10,6 +10,6 @@ mod tests {
|
|||
let output = activation::silu(tensor);
|
||||
let expected = TensorData::from([[0.7311, 1.7616], [2.8577, 3.9281]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,6 @@ mod tests {
|
|||
let output = activation::softmax(tensor, 1);
|
||||
let expected = TensorData::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,6 @@ mod tests {
|
|||
let output = activation::softmin(tensor, 1);
|
||||
let expected = TensorData::from([[9.9753e-01, 2.4726e-03], [1.1254e-07, 1.0000e+00]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,19 +5,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_softplus_d2() {
|
||||
let tensor = Tensor::<TestBackend, 2>::from([
|
||||
[-0.4240, -0.9574, -0.2215],
|
||||
[-0.5767, 0.7218, -0.1620],
|
||||
]);
|
||||
let tensor =
|
||||
TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);
|
||||
|
||||
let output = activation::softplus(tensor.clone(), 1.0);
|
||||
let expected = TensorData::from([[0.5034, 0.3249, 0.5885], [0.4458, 1.1178, 0.6154]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
|
||||
let output = activation::softplus(tensor, 2.0);
|
||||
let expected = TensorData::from([[0.1782, 0.0687, 0.2480], [0.1371, 0.8277, 0.2721]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,6 @@ mod tests {
|
|||
let output = activation::tanh(tensor);
|
||||
let expected = TensorData::from([[0.7616, 0.9640], [0.9951, 0.9993]]);
|
||||
|
||||
output.into_data().assert_approx_eq(&expected, 4);
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,122 +5,59 @@ mod ops;
|
|||
mod quantization;
|
||||
mod stats;
|
||||
|
||||
pub use cubecl::prelude::{Float, Int, Numeric};
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export]
|
||||
macro_rules! testgen_all {
|
||||
() => {
|
||||
// test activation
|
||||
burn_tensor::testgen_gelu!();
|
||||
burn_tensor::testgen_mish!();
|
||||
burn_tensor::testgen_relu!();
|
||||
burn_tensor::testgen_leaky_relu!();
|
||||
burn_tensor::testgen_softmax!();
|
||||
burn_tensor::testgen_softmin!();
|
||||
burn_tensor::testgen_softplus!();
|
||||
burn_tensor::testgen_sigmoid!();
|
||||
burn_tensor::testgen_log_sigmoid!();
|
||||
burn_tensor::testgen_silu!();
|
||||
burn_tensor::testgen_tanh_activation!();
|
||||
pub mod tensor {
|
||||
pub use super::*;
|
||||
|
||||
// test module
|
||||
burn_tensor::testgen_module_forward!();
|
||||
burn_tensor::testgen_module_conv1d!();
|
||||
burn_tensor::testgen_module_conv2d!();
|
||||
burn_tensor::testgen_module_conv3d!();
|
||||
burn_tensor::testgen_module_deform_conv2d!();
|
||||
burn_tensor::testgen_module_conv_transpose1d!();
|
||||
burn_tensor::testgen_module_conv_transpose2d!();
|
||||
burn_tensor::testgen_module_conv_transpose3d!();
|
||||
burn_tensor::testgen_module_unfold4d!();
|
||||
burn_tensor::testgen_module_max_pool1d!();
|
||||
burn_tensor::testgen_module_max_pool2d!();
|
||||
burn_tensor::testgen_module_avg_pool1d!();
|
||||
burn_tensor::testgen_module_avg_pool2d!();
|
||||
burn_tensor::testgen_module_adaptive_avg_pool1d!();
|
||||
burn_tensor::testgen_module_adaptive_avg_pool2d!();
|
||||
burn_tensor::testgen_module_nearest_interpolate!();
|
||||
burn_tensor::testgen_module_bilinear_interpolate!();
|
||||
burn_tensor::testgen_module_bicubic_interpolate!();
|
||||
pub type FloatType = <TestBackend as $crate::backend::Backend>::FloatElem;
|
||||
pub type IntType = <TestBackend as $crate::backend::Backend>::IntElem;
|
||||
pub type BoolType = <TestBackend as $crate::backend::Backend>::BoolTensorPrimitive;
|
||||
|
||||
// test ops
|
||||
burn_tensor::testgen_add!();
|
||||
burn_tensor::testgen_aggregation!();
|
||||
burn_tensor::testgen_arange!();
|
||||
burn_tensor::testgen_arange_step!();
|
||||
burn_tensor::testgen_arg!();
|
||||
burn_tensor::testgen_cast!();
|
||||
burn_tensor::testgen_cat!();
|
||||
burn_tensor::testgen_chunk!();
|
||||
burn_tensor::testgen_clamp!();
|
||||
burn_tensor::testgen_close!();
|
||||
burn_tensor::testgen_cos!();
|
||||
burn_tensor::testgen_create_like!();
|
||||
burn_tensor::testgen_div!();
|
||||
burn_tensor::testgen_erf!();
|
||||
burn_tensor::testgen_exp!();
|
||||
burn_tensor::testgen_flatten!();
|
||||
burn_tensor::testgen_full!();
|
||||
burn_tensor::testgen_gather_scatter!();
|
||||
burn_tensor::testgen_init!();
|
||||
burn_tensor::testgen_iter_dim!();
|
||||
burn_tensor::testgen_log!();
|
||||
burn_tensor::testgen_log1p!();
|
||||
burn_tensor::testgen_map_comparison!();
|
||||
burn_tensor::testgen_mask!();
|
||||
burn_tensor::testgen_matmul!();
|
||||
burn_tensor::testgen_maxmin!();
|
||||
burn_tensor::testgen_mul!();
|
||||
burn_tensor::testgen_narrow!();
|
||||
burn_tensor::testgen_neg!();
|
||||
burn_tensor::testgen_one_hot!();
|
||||
burn_tensor::testgen_powf_scalar!();
|
||||
burn_tensor::testgen_random!();
|
||||
burn_tensor::testgen_recip!();
|
||||
burn_tensor::testgen_repeat_dim!();
|
||||
burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
burn_tensor::testgen_select!();
|
||||
burn_tensor::testgen_sin!();
|
||||
burn_tensor::testgen_slice!();
|
||||
burn_tensor::testgen_stack!();
|
||||
burn_tensor::testgen_sqrt!();
|
||||
burn_tensor::testgen_abs!();
|
||||
burn_tensor::testgen_squeeze!();
|
||||
burn_tensor::testgen_sub!();
|
||||
burn_tensor::testgen_tanh!();
|
||||
burn_tensor::testgen_transpose!();
|
||||
burn_tensor::testgen_tri!();
|
||||
burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_any!();
|
||||
burn_tensor::testgen_all_op!();
|
||||
burn_tensor::testgen_permute!();
|
||||
burn_tensor::testgen_movedim!();
|
||||
burn_tensor::testgen_flip!();
|
||||
burn_tensor::testgen_bool!();
|
||||
burn_tensor::testgen_argwhere_nonzero!();
|
||||
burn_tensor::testgen_sign!();
|
||||
burn_tensor::testgen_expand!();
|
||||
burn_tensor::testgen_tri_mask!();
|
||||
burn_tensor::testgen_sort_argsort!();
|
||||
burn_tensor::testgen_topk!();
|
||||
burn_tensor::testgen_remainder!();
|
||||
burn_tensor::testgen_cartesian_grid!();
|
||||
burn_tensor::testgen_nan!();
|
||||
burn_tensor::testgen_round!();
|
||||
burn_tensor::testgen_floor!();
|
||||
burn_tensor::testgen_ceil!();
|
||||
$crate::testgen_with_float_param!();
|
||||
$crate::testgen_no_param!();
|
||||
}
|
||||
};
|
||||
([$($float:ident),*], [$($int:ident),*]) => {
|
||||
pub mod tensor {
|
||||
pub use super::*;
|
||||
|
||||
// test stats
|
||||
burn_tensor::testgen_var!();
|
||||
burn_tensor::testgen_cov!();
|
||||
burn_tensor::testgen_eye!();
|
||||
burn_tensor::testgen_display!();
|
||||
pub type FloatType = <TestBackend as $crate::backend::Backend>::FloatElem;
|
||||
pub type IntType = <TestBackend as $crate::backend::Backend>::IntElem;
|
||||
pub type BoolType = <TestBackend as $crate::backend::Backend>::BoolTensorPrimitive;
|
||||
|
||||
// test clone invariance
|
||||
burn_tensor::testgen_clone_invariance!();
|
||||
::paste::paste! {
|
||||
$(mod [<$float _ty>] {
|
||||
pub use super::*;
|
||||
|
||||
// test padding
|
||||
burn_tensor::testgen_padding!();
|
||||
pub type TestBackend = TestBackend2<$float, IntType>;
|
||||
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
|
||||
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
|
||||
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;
|
||||
|
||||
pub type FloatType = $float;
|
||||
|
||||
$crate::testgen_with_float_param!();
|
||||
})*
|
||||
$(mod [<$int _ty>] {
|
||||
pub use super::*;
|
||||
|
||||
pub type TestBackend = TestBackend2<FloatType, $int>;
|
||||
pub type TestTensor<const D: usize> = TestTensor2<FloatType, $int, D>;
|
||||
pub type TestTensorInt<const D: usize> = TestTensorInt2<FloatType, $int, D>;
|
||||
pub type TestTensorBool<const D: usize> = TestTensorBool2<FloatType, $int, D>;
|
||||
|
||||
pub type IntType = $int;
|
||||
|
||||
$crate::testgen_with_int_param!();
|
||||
})*
|
||||
}
|
||||
$crate::testgen_no_param!();
|
||||
}
|
||||
};
|
||||
}
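The parameterized arm of `testgen_all!` above leans on the `paste` crate (added as a dependency in this change) to stamp out one module per element type. A minimal, self-contained sketch of that pattern, with purely illustrative module and type names:

// Generates a `<type>_ty` module per listed float type, mirroring how the
// macro above scopes its per-type test suites. Assumes `paste` in Cargo.toml.
macro_rules! per_float_modules {
    ($($float:ident),*) => {
        ::paste::paste! {
            $(mod [<$float _ty>] {
                pub type FloatType = $float;
            })*
        }
    };
}

per_float_modules!(f32, f64);

fn main() {
    // The macro defined sibling modules `f32_ty` and `f64_ty`, each with its
    // own `FloatType` alias for generated tests to reference.
    let _a: f32_ty::FloatType = 1.0;
    let _b: f64_ty::FloatType = 1.0;
}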

@@ -178,3 +115,191 @@ macro_rules! testgen_quantization {
burn_tensor::testgen_q_transpose!();
};
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_with_float_param {
() => {
// test activation
burn_tensor::testgen_gelu!();
burn_tensor::testgen_mish!();
burn_tensor::testgen_relu!();
burn_tensor::testgen_leaky_relu!();
burn_tensor::testgen_softmax!();
burn_tensor::testgen_softmin!();
burn_tensor::testgen_softplus!();
burn_tensor::testgen_sigmoid!();
burn_tensor::testgen_log_sigmoid!();
burn_tensor::testgen_silu!();
burn_tensor::testgen_tanh_activation!();

// test module
burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_conv2d!();
burn_tensor::testgen_module_conv3d!();
burn_tensor::testgen_module_forward!();
burn_tensor::testgen_module_deform_conv2d!();
burn_tensor::testgen_module_conv_transpose1d!();
burn_tensor::testgen_module_conv_transpose2d!();
burn_tensor::testgen_module_conv_transpose3d!();
burn_tensor::testgen_module_unfold4d!();
burn_tensor::testgen_module_max_pool1d!();
burn_tensor::testgen_module_max_pool2d!();
burn_tensor::testgen_module_avg_pool1d!();
burn_tensor::testgen_module_avg_pool2d!();
burn_tensor::testgen_module_adaptive_avg_pool1d!();
burn_tensor::testgen_module_adaptive_avg_pool2d!();
burn_tensor::testgen_module_nearest_interpolate!();
burn_tensor::testgen_module_bilinear_interpolate!();
burn_tensor::testgen_module_bicubic_interpolate!();

// test ops
burn_tensor::testgen_gather_scatter!();
burn_tensor::testgen_narrow!();
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arange!();
burn_tensor::testgen_arange_step!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_chunk!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_close!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();
burn_tensor::testgen_div!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_full!();
burn_tensor::testgen_init!();
burn_tensor::testgen_iter_dim!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_matmul!();
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_one_hot!();
burn_tensor::testgen_powf_scalar!();
burn_tensor::testgen_random!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_slice!();
burn_tensor::testgen_stack!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_abs!();
burn_tensor::testgen_squeeze!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_tri!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_any!();
burn_tensor::testgen_all_op!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_tri_mask!();
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_topk!();
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
burn_tensor::testgen_nan!();
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
burn_tensor::testgen_select!();

// test stats
burn_tensor::testgen_var!();
burn_tensor::testgen_cov!();
burn_tensor::testgen_eye!();

// test padding
burn_tensor::testgen_padding!();
};
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_with_int_param {
() => {
// test ops
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_div!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_stack!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_gather_scatter!();

// test stats
burn_tensor::testgen_eye!();

// test padding
burn_tensor::testgen_padding!();
};
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_no_param {
() => {
// test stats
burn_tensor::testgen_display!();

// test clone invariance
burn_tensor::testgen_clone_invariance!();
};
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! as_bytes {
($ty:ident: $($elem:expr),*) => {
F::as_bytes(&[$($ty::new($elem),)*])
};
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! as_type {
($ty:ident: [$($elem:tt),*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: [$($elem:tt,)*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: $elem:expr) => {
{
use $crate::tests::{Float, Int};

$ty::new($elem)
}
};
}
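`as_type!` rewrites nested array literals element by element through `$ty::new`, so a single literal can be materialized for whatever `FloatType` or `IntType` a generated module uses. A self-contained analogue of that recursive pattern (a plain `as f32` cast stands in for the cubecl `Float`/`Int` constructors; names are illustrative only):

// Recursively maps nested array literals through a conversion, following the
// same two-arm tt-muncher structure as `as_type!` above.
macro_rules! as_f32 {
    ([$($elem:tt),* $(,)?]) => {
        [$(as_f32!($elem)),*]
    };
    ($elem:expr) => {
        ($elem) as f32
    };
}

fn main() {
    let data = as_f32!([[1, 2], [3, 4]]);
    assert_eq!(data, [[1.0_f32, 2.0], [3.0, 4.0]]);
}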

@@ -106,7 +106,7 @@ mod tests {
self.count_include_pad,
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq(&output.into_data(), 2);
}
}
}

@@ -111,7 +111,7 @@ mod tests {
!output
.clone()
.to_data()
.as_slice::<f32>()
.as_slice::<FloatType>()
.unwrap()
.iter()
.any(|&x| x.is_nan()),

@@ -148,7 +148,7 @@ mod tests {
InterpolateOptions::new(InterpolateMode::Bicubic),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.3);
}
}
}

@@ -111,7 +111,7 @@ mod tests {
!output
.clone()
.to_data()
.as_slice::<f32>()
.as_slice::<FloatType>()
.unwrap()
.iter()
.any(|&x| x.is_nan()),

@@ -156,7 +156,7 @@ mod tests {
InterpolateOptions::new(InterpolateMode::Bilinear),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.3);
}
}
}

@@ -26,7 +26,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[0.9074, 0.6387], [0.5160, 0.4196]],
[[2.4259, 1.8008], [1.5449, 1.3112]],
[[3.9444, 2.9629], [2.5738, 2.2027]],

@@ -55,7 +55,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([
test.assert_output(TestTensor::<4>::from([
[
[[0.2155, 0.1928], [0.1934, 0.1755]],
[[0.7251, 0.6759], [0.6877, 0.6485]],

@@ -93,7 +93,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[0.1018, 0.0658], [0.0467, 0.0362]],
[[0.4125, 0.3367], [0.3069, 0.2824]],
[[1.3076, 1.0242], [0.9025, 0.8000]],

@@ -123,7 +123,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[1.0794, 0.7676], [0.7209, 0.5337]],
[[2.7059, 2.0216], [1.9740, 1.5419]],
[[4.3325, 3.2755], [3.2271, 2.5501]],

@@ -153,7 +153,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[1.0669], [0.6329]],
[[2.9741], [2.0383]],
[[4.8812], [3.4437]],

@@ -180,7 +180,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[
[
0.1998, 0.3762, 0.5285, 0.6053, 0.3844, 0.1987, 0.0481, 0.0000,

@@ -264,7 +264,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[1.0647], [0.5783]],
[[2.9289], [1.8829]],
[[4.7931], [3.1875]],

@@ -291,7 +291,7 @@ mod tests {
width: 5,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[[0.6162], [0.7611], [0.4666]],
[[1.8578], [2.2684], [1.6208]],
[[3.0994], [3.7757], [2.7749]],

@@ -318,7 +318,7 @@ mod tests {
width: 4,
};

test.assert_output(Tensor::<TestBackend, 4>::from([[
test.assert_output(TestTensor::<4>::from([[
[
[0.8909, 0.6016],
[1.0697, 0.7186],

@@ -389,29 +389,29 @@ mod tests {
out_width,
]);
let device = Default::default();
let weight = Tensor::<TestBackend, 4>::from(
let weight = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight.clone())
.into_data(),
)
.div_scalar(shape_weight.num_elements() as f32);
let bias = Tensor::<TestBackend, 1>::from(
let bias = TestTensor::<1>::from(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
)
.div_scalar(self.channels_out as f32);
let x = Tensor::<TestBackend, 4>::from(
let x = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x.clone())
.into_data(),
)
.div_scalar(shape_x.num_elements() as f32);
let offset = Tensor::<TestBackend, 4>::from(
let offset = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
.reshape::<4, _>(shape_offset.clone())
.into_data(),
)
.div_scalar(shape_offset.num_elements() as f32);
let mask = Tensor::<TestBackend, 4>::from(
let mask = TestTensor::<4>::from(
TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
.reshape::<4, _>(shape_mask.clone())
.into_data(),

@@ -433,7 +433,7 @@ mod tests {
),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.04);
}
}
}

@@ -7,8 +7,8 @@ mod tests {
fn test_embedding_forward() {
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indices = TensorData::from([[0, 1], [1, 1]]);
let weights = Tensor::<TestBackend, 2>::from(weights);
let indices = Tensor::<TestBackend, 2, Int>::from(indices);
let weights = TestTensor::<2>::from(weights);
let indices = TestTensorInt::<2>::from(indices);

let output = embedding(weights, indices);
let expected = TensorData::from([

@@ -84,7 +84,7 @@ mod tests {
!output
.clone()
.to_data()
.as_slice::<f32>()
.as_slice::<FloatType>()
.unwrap()
.iter()
.any(|&x| x.is_nan()),

@@ -122,7 +122,7 @@ mod tests {
InterpolateOptions::new(InterpolateMode::Nearest),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
y.to_data().assert_approx_eq_diff(&output.into_data(), 0.2);
}
}
}

@@ -8,24 +8,24 @@ mod tests {
fn test_arange() {
let device = <TestBackend as Backend>::Device::default();

let tensor = Tensor::<TestBackend, 1, Int>::arange(2..5, &device);
let tensor = TestTensorInt::<1>::arange(2..5, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([2, 3, 4]), false);

// Test arange with negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange(-10..-5, &device);
let tensor = TestTensorInt::<1>::arange(-10..-5, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([-10, -9, -8, -7, -6]), false);

let tensor = Tensor::<TestBackend, 1, Int>::arange(-3..0, &device);
let tensor = TestTensorInt::<1>::arange(-3..0, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([-3, -2, -1]), false);

// Test arange with a mix of positive and negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange(-2..3, &device);
let tensor = TestTensorInt::<1>::arange(-2..3, &device);
tensor
.clone()
.into_data()

@@ -9,28 +9,28 @@ mod tests {
let device = <TestBackend as Backend>::Device::default();

// Test correct sequence of numbers when the range is 0..9 and the step is 1
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..9, 1, &device);
let tensor = TestTensorInt::<1>::arange_step(0..9, 1, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([0, 1, 2, 3, 4, 5, 6, 7, 8]), false);

// Test correct sequence of numbers when the range is 0..3 and the step is 2
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 2, &device);
let tensor = TestTensorInt::<1>::arange_step(0..3, 2, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([0, 2]), false);

// Test correct sequence of numbers when the range is 0..2 and the step is 5
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..2, 5, &device);
let tensor = TestTensorInt::<1>::arange_step(0..2, 5, &device);
tensor.into_data().assert_eq(&TensorData::from([0]), false);

// Test correct sequence of numbers when the range includes negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-3..3, 2, &device);
let tensor = TestTensorInt::<1>::arange_step(-3..3, 2, &device);
tensor
.into_data()
.assert_eq(&TensorData::from([-3, -1, 1]), false);

let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-5..1, 5, &device);
let tensor = TestTensorInt::<1>::arange_step(-5..1, 5, &device);
tensor
.clone()
.into_data()

@@ -43,6 +43,6 @@ mod tests {
fn should_panic_when_step_is_zero() {
let device = <TestBackend as Backend>::Device::default();
// Test that arange_step panics when the step is 0
let _tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 0, &device);
let _tensor = TestTensorInt::<1>::arange_step(0..3, 0, &device);
}
}

@@ -9,15 +9,14 @@ mod tests {
let device = <TestBackend as Backend>::Device::default();

// Test a single element tensor
let tensor: Tensor<TestBackend, 2, Int> =
Tensor::<TestBackend, 1, Int>::cartesian_grid([1], &device);
let tensor: Tensor<TestBackend, 2, Int> = TestTensorInt::<1>::cartesian_grid([1], &device);
tensor
.into_data()
.assert_eq(&TensorData::from([[0]]), false);

// Test for a 2x2 tensor
let tensor: Tensor<TestBackend, 3, Int> =
Tensor::<TestBackend, 2, Int>::cartesian_grid([2, 2], &device);
TestTensorInt::<2>::cartesian_grid([2, 2], &device);
tensor.into_data().assert_eq(
&TensorData::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]),
false,

@@ -31,9 +31,7 @@ mod tests {

#[test]
fn cast_bool_to_float_tensor() {
let tensor =
Tensor::<TestBackend, 2, Bool>::from([[true, false, true], [false, false, true]])
.float();
let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).float();

let expected = TensorData::from([[1., 0., 1.], [0., 0., 1.]]);

@@ -18,8 +18,8 @@ mod tests {
#[test]
fn should_support_cat_ops_int() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 3]], &device);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data([[4, 5, 6]], &device);
let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device);
let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device);

let output = Tensor::cat(vec![tensor_1, tensor_2], 0);

@@ -31,8 +31,8 @@ mod tests {
#[test]
fn should_support_cat_ops_bool() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data([[false, true, true]], &device);
let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data([[true, true, false]], &device);
let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);
let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);

let output = Tensor::cat(vec![tensor_1, tensor_2], 0);

@@ -6,7 +6,7 @@ mod tests {
#[test]
fn should_support_ceil_ops() {
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let tensor = TestTensor::<2>::from_data(data, &Default::default());

let output = tensor.ceil();
let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]);

@@ -8,7 +8,7 @@ mod tests {
let device = Default::default();
// test float tensor
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);

let output = tensor.clamp_min(2.0);

@@ -18,7 +18,7 @@ mod tests {

// test int tensor
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device);
let tensor = TestTensorInt::<2>::from_data(data, &device);
let output = tensor.clamp_min(2);

output

@@ -31,7 +31,7 @@ mod tests {
let device = Default::default();
// test float tensor
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);

let output = tensor.clamp_max(2.0);

@@ -41,7 +41,7 @@ mod tests {

// test int tensor
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device);
let tensor = TestTensorInt::<2>::from_data(data, &device);
let output = tensor.clamp_max(4);

output

@@ -54,7 +54,7 @@ mod tests {
let device = Default::default();
// test float tensor
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);
let output = tensor.clamp(1.0, 4.0);

output

@@ -63,7 +63,7 @@ mod tests {

// test int tensor
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &device);
let tensor = TestTensorInt::<2>::from_data(data, &device);
let output = tensor.clamp(1, 4);

output

@@ -6,7 +6,7 @@ mod tests {
#[test]
fn should_support_cos_ops() {
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let tensor = TestTensor::<2>::from_data(data, &Default::default());

let output = tensor.cos();
let expected = TensorData::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]);

@@ -8,8 +8,8 @@ mod tests {
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device);
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);

let output = tensor_1 / tensor_2;
let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);

@@ -22,8 +22,8 @@ mod tests {
let data_1 = TensorData::from([[0.0, 1.0, 2.0]]);
let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device);
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);

let output = tensor_1 / tensor_2;

@@ -38,7 +38,7 @@ mod tests {
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let scalar = 2.0;
let device = Default::default();
let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
let tensor = TestTensor::<2>::from_data(data, &device);

let output = tensor / scalar;

@@ -52,8 +52,8 @@ mod tests {
let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let data_2 = TensorData::from([[1, 1, 2], [1, 1, 2]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data(data_2, &device);
let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);
let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);

let output = tensor_1 / tensor_2;

@@ -67,8 +67,8 @@ mod tests {
let data_1 = TensorData::from([[0, 1, 2]]);
let data_2 = TensorData::from([[1, 1, 2], [3, 4, 5]]);
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data(data_1, &device);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data(data_2, &device);
let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);
let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);

let output = tensor_1 / tensor_2;

@@ -81,7 +81,7 @@ mod tests {
fn should_support_div_scalar_ops_int() {
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
let scalar = 2;
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
let tensor = TestTensorInt::<2>::from_data(data, &Default::default());

let output = tensor / scalar;
Some files were not shown because too many files have changed in this diff.