mirror of https://github.com/tracel-ai/burn.git
some debugging done
commit d9b4801448
Cargo.lock

@@ -171,7 +171,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -197,7 +197,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -463,6 +463,7 @@ version = "0.14.0"
 dependencies = [
  "async-trait",
  "dashmap",
+ "data-encoding",
  "derive-new",
  "getrandom",
  "indicatif",
@@ -471,7 +472,6 @@ dependencies = [
  "serde",
  "spin",
  "tokio",
- "uuid",
  "web-time",
 ]

@@ -546,7 +546,7 @@ dependencies = [
  "derive-new",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -604,7 +604,7 @@ dependencies = [
  "derive-new",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -642,7 +642,7 @@ dependencies = [
  "serde_json",
  "strum",
  "strum_macros",
- "syn 2.0.66",
+ "syn 2.0.68",
  "thiserror",
  "tracing-core",
  "tracing-subscriber",
@@ -777,9 +777,9 @@ dependencies = [

 [[package]]
 name = "bytemuck"
-version = "1.16.0"
+version = "1.16.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5"
+checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e"
 dependencies = [
  "bytemuck_derive",
 ]
@@ -792,7 +792,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -843,7 +843,8 @@ dependencies = [
 [[package]]
 name = "candle-core"
 version = "0.5.1"
-source = "git+https://github.com/huggingface/candle.git?rev=82b641f#82b641fd2752e3b14db6a9c91faef70e3329f3b5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "311d8dbe293aa3b5c34f6a57727fafd67d17a74fa8b65276501237c233b34ffd"
 dependencies = [
  "accelerate-src",
  "byteorder",
@@ -869,7 +870,8 @@ dependencies = [
 [[package]]
 name = "candle-kernels"
 version = "0.5.1"
-source = "git+https://github.com/huggingface/candle.git?rev=82b641f#82b641fd2752e3b14db6a9c91faef70e3329f3b5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d3b4b048ca298fb8be90b0f4d0fe68bdca9de956ab52bb6e381463d955f2b661"
 dependencies = [
  "bindgen_cuda",
 ]
@@ -877,7 +879,8 @@ dependencies = [
 [[package]]
 name = "candle-metal-kernels"
 version = "0.5.1"
-source = "git+https://github.com/huggingface/candle.git?rev=82b641f#82b641fd2752e3b14db6a9c91faef70e3329f3b5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8d31136c9541c82b7de0937c9a58210ada38e17d70810e0eacc0a99d849d848d"
 dependencies = [
  "metal 0.27.0",
  "once_cell",
@@ -1027,7 +1030,7 @@ dependencies = [
  "heck 0.5.0",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1439,7 +1442,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "strsim 0.10.0",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1450,7 +1453,7 @@ checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
 dependencies = [
  "darling_core",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1466,6 +1469,12 @@ dependencies = [
  "parking_lot_core 0.9.10",
 ]

+[[package]]
+name = "data-encoding"
+version = "2.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
+
 [[package]]
 name = "deflate64"
 version = "0.1.8"
@@ -1489,7 +1498,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1500,7 +1509,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1521,7 +1530,7 @@ dependencies = [
  "darling",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1531,7 +1540,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
 dependencies = [
  "derive_builder_core",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1542,7 +1551,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1617,7 +1626,7 @@ checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1669,7 +1678,7 @@ dependencies = [
  "heck 0.4.1",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1855,7 +1864,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -1946,7 +1955,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -2879,7 +2888,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -3279,7 +3288,7 @@ checksum = "bf307cbbbd777a9c10cec88ddafee572b3484caad5cce0c9236523c3803105a6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -3439,7 +3448,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -3701,7 +3710,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -3861,7 +3870,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -3961,9 +3970,9 @@ dependencies = [

 [[package]]
 name = "proc-macro2"
-version = "1.0.85"
+version = "1.0.86"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
+checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
 dependencies = [
  "unicode-ident",
 ]
@@ -3984,7 +3993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
 dependencies = [
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -4535,7 +4544,7 @@ dependencies = [
  "regex",
  "relative-path",
  "rustc_version",
- "syn 2.0.66",
+ "syn 2.0.68",
  "unicode-ident",
 ]

@@ -4813,7 +4822,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -4880,7 +4889,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5030,7 +5039,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2ff9eaf853dec4c8802325d8b6d3dffa86cc707fd7a1a4cdbf416e13b061787a"
 dependencies = [
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5059,9 +5068,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"

 [[package]]
 name = "strum"
-version = "0.26.2"
+version = "0.26.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
+checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
 dependencies = [
  "strum_macros",
 ]
@@ -5076,7 +5085,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "rustversion",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5098,9 +5107,9 @@ dependencies = [

 [[package]]
 name = "syn"
-version = "2.0.66"
+version = "2.0.68"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
+checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -5121,7 +5130,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5312,7 +5321,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5450,7 +5459,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5586,7 +5595,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -5849,7 +5858,7 @@ dependencies = [
  "once_cell",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
  "wasm-bindgen-shared",
 ]

@@ -5883,7 +5892,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
  "wasm-bindgen-backend",
  "wasm-bindgen-shared",
 ]
@@ -5957,9 +5966,9 @@ checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"

 [[package]]
 name = "wgpu"
-version = "0.20.0"
+version = "0.20.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32ff1bfee408e1028e2e3acbf6d32d98b08a5a059ccbf5f33305534453ba5d3e"
+checksum = "90e37c7b9921b75dfd26dd973fdcbce36f13dfa6e2dc82aece584e0ed48c355c"
 dependencies = [
  "arrayvec",
  "cfg-if",
@@ -5983,9 +5992,9 @@ dependencies = [

 [[package]]
 name = "wgpu-core"
-version = "0.20.0"
+version = "0.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ac6a86eaa5e763e59c73cf9e97d55fffd4dfda69fd8bda19589fcf851ddfef1f"
+checksum = "d59e0d5fc509601c69e4e1fa06c1eb3c4c9f12956a5e30c79b61ef1c1be7daf0"
 dependencies = [
  "arrayvec",
  "bit-vec",
@@ -6010,9 +6019,9 @@ dependencies = [

 [[package]]
 name = "wgpu-hal"
-version = "0.20.0"
+version = "0.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4d71c8ae05170583049b65ee562fd839fdc0b3e9ddb84f4e40c9d5f8ea0d4c8c"
+checksum = "6aa24c3889f885a3fb9133b454c8418bfcfaadcfe4ed3be96ac80e76703b863b"
 dependencies = [
  "android_system_properties",
  "arrayvec",
@@ -6309,7 +6318,7 @@ dependencies = [
  "darling",
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -6392,7 +6401,7 @@ checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
  "synstructure",
 ]

@@ -6413,7 +6422,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
@@ -6433,7 +6442,7 @@ checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
  "synstructure",
 ]

@@ -6454,7 +6463,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.66",
+ "syn 2.0.68",
 ]

 [[package]]
Cargo.toml

@@ -27,14 +27,16 @@ license = "MIT OR Apache-2.0"

 [workspace.dependencies]
 async-trait = "0.1.80"
-bytemuck = "1.16.0"
-# candle-core = { version = "0.4.1" }
-candle-core = { git = "https://github.com/huggingface/candle.git", rev = "82b641f" }
+bytemuck = "1.16.1"
+candle-core = { version = "0.5.1" }
 clap = { version = "4.5.7", features = ["derive"] }
 colored = "2.1.0"
 console_error_panic_hook = "0.1.7"
 csv = "1.3.0"
 dashmap = "5.5.3"
+data-encoding = { version = "2.6.0", default-features = false, features = [
+    "alloc",
+] }
 dirs = "5.0.1"
 fake = "2.9.2"
 flate2 = "1.0.30"
@@ -43,16 +45,19 @@ getrandom = { version = "0.2.15", default-features = false }
 gix-tempfile = { version = "13.1.1", features = ["signals"] }
 globwalk = "0.9.1"
 hashbrown = "0.14.5"
+hound = "3.5.1"
+image = "0.25.1"
 indicatif = "0.17.8"
 js-sys = "0.3.69"
 libm = "0.2.8"
 log = { default-features = false, version = "0.4.21" }
+md5 = "0.7.0"
+percent-encoding = "2.3.1"
 pretty_assertions = "1.4.0"
-proc-macro2 = "1.0.85"
+proc-macro2 = "1.0.86"
 protobuf = "3.4.0"
 protobuf-codegen = "3.4.0"
 quote = "1.0.36"
-percent-encoding = "2.3.1"
 r2d2 = "0.8.10"
 r2d2_sqlite = { version = "0.24.0" }
 rayon = "1.10.0"
@@ -64,21 +69,18 @@ rusqlite = { version = "0.31.0" }
 rust-format = { version = "0.3.4" }
 sanitize-filename = "0.5.0"
 serde_rusqlite = "0.35.0"
+serial_test = "3.1.1"
 spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
-strum = "0.26.2"
+strum = "0.26.3"
 strum_macros = "0.26.4"
-syn = { version = "2.0.66", features = ["full", "extra-traits"] }
+syn = { version = "2.0.68", features = ["full", "extra-traits"] }
 tempfile = "3.10.1"
 thiserror = "1.0.61"
 tokio = { version = "1.38.0", features = ["rt", "macros"] }
 tracing-appender = "0.2.3"
 tracing-core = "0.1.32"
 tracing-subscriber = "0.3.18"
-md5 = "0.7.0"
-serial_test = "3.1.1"
 web-time = "1.1.0"
-hound = "3.5.1"
-image = "0.25.1"
 zip = "2.1.3"

 # Terminal UI
@@ -89,7 +91,7 @@ crossterm = "0.27.0"
 futures-intrusive = "0.5.0"
 text_placeholder = "0.5.0"
 pollster = "0.3.0"
-wgpu = "0.20.0"
+wgpu = "0.20.1"

 # Benchmarks and Burnbench
 arboard = "3.4.0"
@@ -1,12 +1,12 @@
 # Learner

-The [burn-train](https://github.com/tracel-ai/burn/tree/main/burn-train) crate encapsulates multiple
-utilities for training deep learning models. The goal of the crate is to provide users with a
-well-crafted and flexible training loop, so that projects do not have to write such components from
-the ground up. Most of the interactions with `burn-train` will be with the `LearnerBuilder` struct,
-briefly presented in the previous [training section](../basic-workflow/training.md). This struct
-enables you to configure the training loop, offering support for registering metrics, enabling
-logging, checkpointing states, using multiple devices, and so on.
+The [burn-train](https://github.com/tracel-ai/burn/tree/main/crates/burn-train) crate encapsulates
+multiple utilities for training deep learning models. The goal of the crate is to provide users with
+a well-crafted and flexible training loop, so that projects do not have to write such components
+from the ground up. Most of the interactions with `burn-train` will be with the `LearnerBuilder`
+struct, briefly presented in the previous [training section](../basic-workflow/training.md). This
+struct enables you to configure the training loop, offering support for registering metrics,
+enabling logging, checkpointing states, using multiple devices, and so on.

 There are still some assumptions in the current provided APIs, which may make them inappropriate for
 your learning requirements. Indeed, they assume your model will learn from a training dataset and be
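For context, the `LearnerBuilder` flow that this page documents looks roughly like the following. This is a minimal sketch based on the Burn book, not part of this commit; the metric, recorder, learning rate, and data loader names are illustrative assumptions:

    use burn::record::CompactRecorder;
    use burn::train::metric::LossMetric;
    use burn::train::LearnerBuilder;

    // `artifact_dir`, `model`, `optim`, `lr`, `device` and the two dataloaders
    // are assumed to be constructed beforehand, as in the basic-workflow chapter.
    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(LossMetric::new())   // track training loss
        .metric_valid_numeric(LossMetric::new())   // track validation loss
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device])
        .num_epochs(10)
        .build(model, optim, lr);

    let model_trained = learner.fit(dataloader_train, dataloader_valid);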
@@ -12,7 +12,7 @@ version.workspace = true

 [features]
 default = ["std"]
-std = ["rand/std"]
+std = ["rand/std", "data-encoding/std"]
 doc = ["default"]
 wasm-sync = []
 network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
@@ -27,10 +27,10 @@ web-time = { version = "1.1.0" }
 # ** Please make sure all dependencies support no_std when std is disabled **

 rand = { workspace = true }
-spin = { workspace = true } # using in place of use std::sync::Mutex;
-uuid = { workspace = true }
+spin = { workspace = true } # using in place of use std::sync::Mutex;
 derive-new = { workspace = true }
 serde = { workspace = true }
+data-encoding = { workspace = true }

 # Network downloader
 indicatif = { workspace = true, optional = true }
@@ -1,18 +1,21 @@
-use alloc::string::String;
-
 use crate::rand::gen_random;
-use uuid::{Builder, Bytes};
+use alloc::string::{String, ToString};
+
+use data_encoding::BASE32_DNSSEC;

 /// Simple ID generator.
 pub struct IdGenerator {}

 impl IdGenerator {
-    /// Generates a new ID in the form of a UUID.
+    /// Generates a new ID.
     pub fn generate() -> String {
-        let random_bytes: Bytes = gen_random();
+        // Generate 6 random bytes (281,474,976,710,656 combinations)
+        let random_bytes: [u8; 6] = gen_random();

-        let uuid = Builder::from_random_bytes(random_bytes).into_uuid();
-
-        uuid.as_hyphenated().to_string()
+        // Encode the random bytes in base32 DNSSEC
+        // 6 bytes encodes to 10 lower case characters, e.g. "3uu5e6vv7c"
+        BASE32_DNSSEC.encode(&random_bytes)
     }
 }
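The change above trades hyphenated UUIDs for short base32 IDs. A standalone sketch of the same encoding step, useful for sanity-checking the size claim in the comments (it assumes only the `data-encoding` crate used above; the byte values are arbitrary):

    use data_encoding::BASE32_DNSSEC;

    fn main() {
        // 6 bytes = 48 bits, and ceil(48 / 5) = 10 base32 characters,
        // matching the 2^48 = 281,474,976,710,656 combinations noted above.
        let random_bytes: [u8; 6] = [0x3a, 0x91, 0x5e, 0x07, 0xc4, 0x22];
        let id = BASE32_DNSSEC.encode(&random_bytes);
        assert_eq!(id.len(), 10);
        println!("{id}"); // ten lowercase characters; the value depends on the bytes
    }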
@@ -0,0 +1,547 @@
+use alloc::{
+    borrow::ToOwned,
+    format,
+    string::{String, ToString},
+    vec::Vec,
+};
+use core::any;
+use core::fmt::{Display, Write};
+
+/// Default display settings for a module.
+pub trait ModuleDisplayDefault {
+    /// Attributes of the module used for display purposes.
+    ///
+    /// # Arguments
+    ///
+    /// * `_content` - The content object that contains display settings and attributes.
+    ///
+    /// # Returns
+    ///
+    /// An optional content object containing the display attributes.
+    fn content(&self, _content: Content) -> Option<Content>;
+
+    /// Gets the number of the parameters of the module.
+    fn num_params(&self) -> usize {
+        0
+    }
+}
+
+/// Trait to implement custom display settings for a module.
+///
+/// In order to implement custom display settings for a module,
+/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
+/// 2. Implement ModuleDisplay trait for the module
+pub trait ModuleDisplay: ModuleDisplayDefault {
+    /// Formats the module with provided display settings.
+    ///
+    /// # Arguments
+    ///
+    /// * `passed_settings` - Display settings passed to the module.
+    ///
+    /// # Returns
+    ///
+    /// A string representation of the formatted module.
+    fn format(&self, passed_settings: DisplaySettings) -> String {
+        let settings = if let Some(custom_settings) = self.custom_settings() {
+            custom_settings.inherit(passed_settings)
+        } else {
+            passed_settings
+        };
+
+        let indent = " ".repeat(settings.level * settings.indentation_size());
+        let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
+
+        let settings = settings.level_up();
+
+        let self_type = extract_type_name::<Self>();
+
+        // Use custom content if it is implemented and show_all_attributes is false,
+        // otherwise use default content
+        let content = if !settings.show_all_attributes() {
+            self.custom_content(Content::new(settings.clone()))
+                .unwrap_or_else(|| {
+                    self.content(Content::new(settings.clone()))
+                        .unwrap_or_else(|| {
+                            panic!("Default content should be implemented for {self_type}.")
+                        })
+                })
+        } else {
+            self.content(Content::new(settings.clone()))
+                .unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
+        };
+
+        let top_level_type = if let Some(top_level_type) = content.top_level_type {
+            top_level_type.to_owned()
+        } else {
+            self_type.to_owned()
+        };
+
+        // If there is only one item in the content, return it or no attributes
+        if let Some(item) = content.single_item {
+            return item;
+        } else if content.attributes.is_empty() {
+            return top_level_type.to_string();
+        }
+
+        let mut result = String::new();
+
+        // Print the struct name
+        if settings.new_line_after_attribute() {
+            writeln!(result, "{} {{", top_level_type).unwrap();
+        } else {
+            write!(result, "{} {{", top_level_type).unwrap();
+        }
+
+        for (i, attribute) in content.attributes.iter().enumerate() {
+            if settings.new_line_after_attribute() {
+                writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
+            } else if i == 0 {
+                write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
+            } else {
+                write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
+            }
+        }
+
+        if settings.show_num_parameters() {
+            let num_params = self.num_params();
+            if num_params > 0 {
+                if settings.new_line_after_attribute() {
+                    writeln!(result, "{indent}params: {}", num_params).unwrap();
+                } else {
+                    write!(result, ", params: {}", num_params).unwrap();
+                }
+            }
+        }
+
+        if settings.new_line_after_attribute() {
+            write!(result, "{indent_close_braces}}}").unwrap();
+        } else {
+            write!(result, "}}").unwrap();
+        }
+
+        result
+    }
+
+    /// Custom display settings for the module.
+    ///
+    /// # Returns
+    ///
+    /// An optional display settings object.
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        None
+    }
+
+    /// Custom attributes for the module.
+    ///
+    /// # Arguments
+    ///
+    /// * `_content` - The content object that contains display settings and attributes.
+    ///
+    /// # Returns
+    ///
+    /// An optional content object containing the custom attributes.
+    fn custom_content(&self, _content: Content) -> Option<Content> {
+        None
+    }
+}
+
+/// Custom module display settings.
+#[derive(Debug, Clone)]
+pub struct DisplaySettings {
+    /// Whether to print the module parameter ids.
+    show_param_id: Option<bool>,
+
+    /// Whether to print the module attributes.
+    show_all_attributes: Option<bool>,
+
+    /// Whether to print the module number of parameters.
+    show_num_parameters: Option<bool>,
+
+    /// Print new line after an attribute.
+    new_line_after_attribute: Option<bool>,
+
+    /// Indentation size.
+    indentation_size: Option<usize>,
+
+    /// Level of indentation.
+    level: usize,
+}
+
+impl Default for DisplaySettings {
+    fn default() -> Self {
+        DisplaySettings {
+            show_param_id: None,
+            show_all_attributes: None,
+            show_num_parameters: None,
+            new_line_after_attribute: None,
+            indentation_size: None,
+            level: 1,
+        }
+    }
+}
+
+impl DisplaySettings {
+    /// Create a new format settings.
+    ///
+    /// # Returns
+    ///
+    /// A new instance of `DisplaySettings`.
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// Sets a flag to show module parameters.
+    ///
+    /// # Arguments
+    ///
+    /// * `flag` - Boolean flag to show module parameters.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance.
+    pub fn with_show_param_id(mut self, flag: bool) -> Self {
+        self.show_param_id = Some(flag);
+        self
+    }
+
+    /// Sets a flag to show module attributes.
+    ///
+    /// # Arguments
+    ///
+    /// * `flag` - Boolean flag to show all module attributes.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance.
+    pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
+        self.show_all_attributes = Some(flag);
+        self
+    }
+
+    /// Sets a flag to show the number of module parameters.
+    ///
+    /// # Arguments
+    ///
+    /// * `flag` - Boolean flag to show the number of module parameters.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance.
+    pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
+        self.show_num_parameters = Some(flag);
+        self
+    }
+
+    /// Sets a flag to print a new line after an attribute.
+    ///
+    /// # Arguments
+    ///
+    /// * `flag` - Boolean flag to print a new line after an attribute.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance.
+    pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
+        self.new_line_after_attribute = Some(flag);
+        self
+    }
+
+    /// Sets the indentation size.
+    ///
+    /// # Arguments
+    ///
+    /// * `size` - The size of the indentation.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance.
+    pub fn with_indentation_size(mut self, size: usize) -> Self {
+        self.indentation_size = Some(size);
+        self
+    }
+
+    /// Inherits settings from the provided settings and return a new settings object.
+    ///
+    /// # Arguments
+    ///
+    /// * `top` - The top level `DisplaySettings` to inherit from.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance.
+    pub fn inherit(self, top: Self) -> Self {
+        let mut updated = self.clone();
+
+        if let Some(show_param_id) = top.show_param_id {
+            updated.show_param_id = Some(show_param_id);
+        };
+
+        if let Some(show_all_attributes) = top.show_all_attributes {
+            updated.show_all_attributes = Some(show_all_attributes);
+        }
+
+        if let Some(show_num_parameters) = top.show_num_parameters {
+            updated.show_num_parameters = Some(show_num_parameters);
+        }
+
+        if let Some(new_line_after_attribute) = top.new_line_after_attribute {
+            updated.new_line_after_attribute = Some(new_line_after_attribute);
+        }
+
+        if let Some(indentation_size) = top.indentation_size {
+            updated.indentation_size = Some(indentation_size);
+        }
+
+        updated.level = top.level;
+
+        updated
+    }
+
+    /// A convenience method to wrap the DisplaySettings struct in an option.
+    ///
+    /// # Returns
+    ///
+    /// An optional `DisplaySettings`.
+    pub fn optional(self) -> Option<Self> {
+        Some(self)
+    }
+
+    /// Increases the level of indentation.
+    ///
+    /// # Returns
+    ///
+    /// Updated `DisplaySettings` instance with increased indentation level.
+    pub fn level_up(mut self) -> Self {
+        self.level += 1;
+        self
+    }
+
+    /// Gets `show_param_id` flag, substitutes false if not set.
+    ///
+    /// This flag is used to print the module parameter ids.
+    ///
+    /// # Returns
+    ///
+    /// A boolean value indicating whether to show parameter ids.
+    pub fn show_param_id(&self) -> bool {
+        self.show_param_id.unwrap_or(false)
+    }
+
+    /// Gets `show_all_attributes`, substitutes false if not set.
+    ///
+    /// This flag is used to force to print all module attributes, overriding custom attributes.
+    ///
+    /// # Returns
+    ///
+    /// A boolean value indicating whether to show all attributes.
+    pub fn show_all_attributes(&self) -> bool {
+        self.show_all_attributes.unwrap_or(false)
+    }
+
+    /// Gets `show_num_parameters`, substitutes true if not set.
+    ///
+    /// This flag is used to print the number of module parameters.
+    ///
+    /// # Returns
+    ///
+    /// A boolean value indicating whether to show the number of parameters.
+    pub fn show_num_parameters(&self) -> bool {
+        self.show_num_parameters.unwrap_or(true)
+    }
+
+    /// Gets `new_line_after_attribute`, substitutes true if not set.
+    ///
+    /// This flag is used to print a new line after an attribute.
+    ///
+    /// # Returns
+    ///
+    /// A boolean value indicating whether to print a new line after an attribute.
+    pub fn new_line_after_attribute(&self) -> bool {
+        self.new_line_after_attribute.unwrap_or(true)
+    }
+
+    /// Gets `indentation_size`, substitutes 2 if not set.
+    ///
+    /// This flag is used to set the size of indentation.
+    ///
+    /// # Returns
+    ///
+    /// An integer value indicating the size of indentation.
+    pub fn indentation_size(&self) -> usize {
+        self.indentation_size.unwrap_or(2)
+    }
+}
+
+/// Struct to store the attributes of a module for formatting.
+#[derive(Clone, Debug)]
+pub struct Content {
+    /// List of attributes.
+    pub attributes: Vec<Attribute>,
+
+    /// Single item content.
+    pub single_item: Option<String>,
+
+    /// Display settings.
+    pub display_settings: DisplaySettings,
+
+    /// Top level type name.
+    pub top_level_type: Option<String>,
+}
+
+impl Content {
+    /// Creates a new attributes struct.
+    ///
+    /// # Arguments
+    ///
+    /// * `display_settings` - Display settings for the content.
+    ///
+    /// # Returns
+    ///
+    /// A new instance of `Content`.
+    pub fn new(display_settings: DisplaySettings) -> Self {
+        Content {
+            attributes: Vec::new(),
+            single_item: None,
+            display_settings,
+            top_level_type: None,
+        }
+    }
+
+    /// Adds an attribute to the format settings. The value will be formatted and stored as a string.
+    ///
+    /// # Arguments
+    ///
+    /// * `name` - Name of the attribute.
+    /// * `value` - Value of the attribute.
+    ///
+    /// # Returns
+    ///
+    /// Updated `Content` instance with the new attribute added.
+    pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
+        if self.single_item.is_some() {
+            panic!("Cannot add multiple attributes when single item is set.");
+        }
+
+        let attribute = Attribute {
+            name: name.to_owned(),
+            value: value.format(self.display_settings.clone()), // TODO level + 1
+            ty: any::type_name::<T>().to_string(),
+        };
+        self.attributes.push(attribute);
+        self
+    }
+
+    /// Adds a single item.
+    ///
+    /// # Arguments
+    ///
+    /// * `value` - Rendered string of the single item.
+    ///
+    /// # Returns
+    ///
+    /// Updated `Content` instance with the single item added.
+    pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
+        if !self.attributes.is_empty() {
+            panic!("Cannot add single item when attributes are set.");
+        }
+
+        self.single_item = Some(value.format(self.display_settings.clone()));
+
+        self
+    }
+
+    /// Adds a single item.
+    ///
+    /// # Arguments
+    ///
+    /// * `value` - Formatted display value.
+    ///
+    /// # Returns
+    ///
+    /// Updated `Content` instance with the formatted single item added.
+    pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
+        if !self.attributes.is_empty() {
+            panic!("Cannot add single item when attributes are set.");
+        }
+
+        self.single_item = Some(format!("{}", value));
+        self
+    }
+
+    /// A convenience method to wrap the Attributes struct in an option
+    /// because it is often used as an optional field.
+    ///
+    /// # Returns
+    ///
+    /// An optional `Content`.
+    pub fn optional(self) -> Option<Self> {
+        if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
+        {
+            None
+        } else {
+            Some(self)
+        }
+    }
+
+    /// Sets the top level type name.
+    ///
+    /// # Arguments
+    ///
+    /// * `ty` - The type name to set.
+    ///
+    /// # Returns
+    ///
+    /// Updated `Content` instance with the top level type name set.
+    pub fn set_top_level_type(mut self, ty: &str) -> Self {
+        self.top_level_type = Some(ty.to_owned());
+        self
+    }
+}
+
+/// Attribute to print in the display method.
+#[derive(Clone, Debug)]
+pub struct Attribute {
+    /// Name of the attribute.
+    pub name: String,
+
+    /// Value of the attribute.
+    pub value: String,
+
+    /// Type of the attribute.
+    pub ty: String,
+}
+
+/// Extracts the short name of a type T
+///
+/// # Returns
+///
+/// A string slice representing the short name of the type.
+pub fn extract_type_name<T: ?Sized>() -> &'static str {
+    // Get the full type name of T, including module path and generic parameters
+    let ty = any::type_name::<T>();
+
+    // Find the first occurrence of '<' in the full type name
+    // If not found, use the length of the type name
+    let end = ty.find('<').unwrap_or(ty.len());
+
+    // Slice the type name up to the first '<' or the end
+    let ty = &ty[0..end];
+
+    // Find the last occurrence of "::" in the sliced type name
+    // If found, add 2 to skip the "::" itself
+    // If not found, start from the beginning of the type name
+    let start = ty.rfind("::").map(|i| i + 2).unwrap_or(0);
+
+    // Find the last occurrence of '<' in the sliced type name
+    // If not found, use the length of the type name
+    let end = ty.rfind('<').unwrap_or(ty.len());
+
+    // If the start index is less than the end index,
+    // return the slice of the type name from start to end
+    // Otherwise, return the entire sliced type name
+    if start < end {
+        &ty[start..end]
+    } else {
+        ty
+    }
+}
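The `ModuleDisplay` doc comment in the new file above lists the two steps for custom display. Putting them together on a hypothetical module (the `MyBlock` struct and its fields are illustrative only; `Linear` and `Dropout` gain their own `ModuleDisplay` implementations later in this commit, and the `#[module(custom_display)]` attribute is handled by the derive macro):

    #[derive(Module, Debug)]
    #[module(custom_display)]
    pub struct MyBlock<B: Backend> {
        linear: Linear<B>,
        dropout: Dropout,
    }

    impl<B: Backend> ModuleDisplay for MyBlock<B> {
        fn custom_settings(&self) -> Option<DisplaySettings> {
            // One attribute per line, as in the default settings.
            DisplaySettings::new()
                .with_new_line_after_attribute(true)
                .optional()
        }

        fn custom_content(&self, content: Content) -> Option<Content> {
            // Each child module is formatted recursively via its own ModuleDisplay.
            content
                .add("linear", &self.linear)
                .add("dropout", &self.dropout)
                .optional()
        }
    }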
@@ -1,5 +1,7 @@
 mod base;
+mod display;
 mod param;

 pub use base::*;
+pub use display::*;
 pub use param::*;
@@ -1,6 +1,12 @@
+use alloc::{format, string::ToString};
+use core::{fmt::Display, marker::PhantomData};
+
 use crate::{
     self as burn,
-    module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
+    module::{
+        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
+        ModuleMapper, ModuleVisitor,
+    },
     record::Record,
 };
 use burn::record::PrecisionSettings;
@@ -8,7 +14,6 @@ use burn_tensor::{
     backend::{AutodiffBackend, Backend},
     BasicAutodiffOps, BasicOps, Tensor,
 };
-use core::marker::PhantomData;

 /// Record used for constant type implementing the [module](crate::module::Module) trait.
 #[derive(Debug, Clone, Copy, new, Default)]
@@ -96,6 +101,15 @@ macro_rules! constant {
         impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
             constant!(ad_module, $type);
         }
+
+        impl burn::module::ModuleDisplayDefault for $type {
+            fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
+                let string = format!("{}", self);
+                content.add_formatted(&string).optional()
+            }
+        }
+
+        impl burn::module::ModuleDisplay for $type {}
     };
 }
@@ -122,6 +136,13 @@ constant!(i32);
 constant!(i16);
 constant!(i8);

+impl burn::module::ModuleDisplay for str {}
+impl burn::module::ModuleDisplayDefault for str {
+    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
+        content.add_formatted(&self).optional()
+    }
+}
+
 impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
     type Record = ConstantRecord;

@@ -158,6 +179,15 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
     }
 }

+impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
+    fn content(&self, content: Content) -> Option<Content> {
+        let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
+        content.add_single(&string).optional()
+    }
+}
+
+impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
+
 impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
     for Tensor<B, D, K>
 {
@@ -200,6 +230,14 @@ impl<B: Backend> Module<B> for PhantomData<B> {
     }
 }

+impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
+    fn content(&self, content: Content) -> Option<Content> {
+        content.add_single(&"PhantomData".to_string()).optional()
+    }
+}
+
+impl<B: Backend> ModuleDisplay for PhantomData<B> {}
+
 impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
     type InnerModule = PhantomData<B::InnerBackend>;

@@ -248,6 +286,27 @@ where
     }
 }

+impl<T> ModuleDisplayDefault for Ignored<T>
+where
+    T: Sync + Send + core::fmt::Debug + Clone,
+{
+    fn content(&self, content: Content) -> Option<Content> {
+        // For now, just print the debug representation of the ignored value
+        content.add_single(&format!("{:?}", self.0)).optional()
+    }
+}
+
+impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
+
+impl<T> Display for Ignored<T>
+where
+    T: Sync + Send + core::fmt::Debug + Clone,
+{
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{:?}", self.0)
+    }
+}
+
 impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
 where
     B: AutodiffBackend,
@@ -1,5 +1,10 @@
-use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
-use alloc::vec::Vec;
+use crate::module::{
+    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
+    ModuleVisitor,
+};
+
+use alloc::{format, vec::Vec};

 use burn_tensor::backend::{AutodiffBackend, Backend};
 use core::fmt::Debug;
@@ -52,6 +57,17 @@ where
     }
 }

+impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
+    fn content(&self, content: Content) -> Option<Content> {
+        match self {
+            Some(module) => content.add_single(module).optional(),
+            None => content.add_single("None").optional(),
+        }
+    }
+}
+
+impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
+
 impl<T, B> AutodiffModule<B> for Option<T>
 where
     T: AutodiffModule<B> + Debug + Send + Clone,
@@ -128,6 +144,21 @@ where
     }
 }

+impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
+    fn content(&self, content: Content) -> Option<Content> {
+        self.iter()
+            .enumerate()
+            .fold(content, |acc, (i, module)| {
+                let index = format!("{}", i);
+                acc.add(&index, module)
+            })
+            .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
+            .optional()
+    }
+}
+
+impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
+
 impl<T, B> AutodiffModule<B> for Vec<T>
 where
     T: AutodiffModule<B> + Debug + Send + Clone,
@@ -197,6 +228,21 @@ where
     }
 }

+impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
+    fn content(&self, content: Content) -> Option<Content> {
+        self.iter()
+            .enumerate()
+            .fold(content, |acc, (i, module)| {
+                let index = format!("{}", i);
+                acc.add(&index, module)
+            })
+            .set_top_level_type(format!("[0..{}]", self.len()).as_str())
+            .optional()
+    }
+}
+
+impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
+
 impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
 where
     T: AutodiffModule<B> + Debug + Send + Clone + Copy,
@@ -269,6 +315,21 @@ macro_rules! impl_module_tuple {
                 ($(self.$i.valid(),)*)
             }
         }

+        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
+        where
+            $($l: ModuleDisplay,)*
+        {
+            fn content(&self, content: Content) -> Option<Content> {
+                let content = content
+                    $(.add(&format!("{}", $i), &self.$i))*
+                    .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
+                content.optional()
+            }
+        }
+
+        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
+
     };
 }
@@ -1,7 +1,13 @@
 use super::ParamId;
-use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
+use crate::module::{
+    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
+    ModuleVisitor, Param,
+};
+
+use alloc::string::ToString;
 use alloc::sync::Arc;
 use alloc::vec::Vec;

 use burn_common::stub::Mutex;
 use burn_tensor::{
     backend::{AutodiffBackend, Backend},
@@ -45,6 +51,24 @@ pub struct RunningState<V> {
     value: Arc<Mutex<V>>,
 }

+// Implement display for the module
+
+impl<V> core::fmt::Display for RunningState<V> {
+    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        write!(f, "RunningState(id={})", self.id)
+    }
+}
+
+impl<V> ModuleDisplayDefault for RunningState<V> {
+    fn content(&self, content: Content) -> Option<Content> {
+        content
+            .add_formatted(&"RunningState".to_string())
+            .optional()
+    }
+}
+
+impl<V> ModuleDisplay for RunningState<V> {}
+
 impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
     type Record = Param<Tensor<B, D>>;
@@ -1,10 +1,13 @@
 use super::{Param, ParamId, Parameter};
-use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
+use crate::module::{
+    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
+    ModuleVisitor,
+};
 use crate::tensor::{
     backend::{AutodiffBackend, Backend},
     Tensor,
 };
-use alloc::vec::Vec;
+use alloc::{format, string::ToString, vec::Vec};
 use burn_tensor::{Bool, Data, Float, Int};

 impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
@@ -147,6 +150,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
     }
 }

+impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
+    fn content(&self, content: Content) -> Option<Content> {
+        let id = if content.display_settings.show_param_id() {
+            format!(", id: {}", self.id)
+        } else {
+            "".to_string()
+        };
+        let string = format!(
+            "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
+            self.shape().dims
+        );
+        content.add_formatted(&string).optional()
+    }
+}
+impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
+
 impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
     type Record = Param<Tensor<B, D, Int>>;
@@ -198,6 +217,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
     }
 }

+impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
+    fn content(&self, content: Content) -> Option<Content> {
+        let id = if content.display_settings.show_param_id() {
+            format!(", id: {}", self.id)
+        } else {
+            "".to_string()
+        };
+        let string = format!(
+            "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
+            self.shape().dims
+        );
+        content.add_formatted(&string).optional()
+    }
+}
+impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
+
 impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
     type Record = Param<Tensor<B, D, Bool>>;
@@ -249,6 +284,24 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
     }
 }

+impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
+    fn content(&self, content: Content) -> Option<Content> {
+        let id = if content.display_settings.show_param_id() {
+            format!(", id: {}", self.id)
+        } else {
+            "".to_string()
+        };
+
+        let string = format!(
+            "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
+            self.shape().dims
+        );
+        content.add_formatted(&string).optional()
+    }
+}
+
+impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
+
 impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
     type InnerModule = Param<Tensor<B::InnerBackend, D>>;
@@ -1,14 +1,13 @@
+use alloc::format;
+
 use crate as burn;

-use crate::config::Config;
-use crate::module::Module;
-use crate::module::Param;
-use crate::nn::conv::checks;
-use crate::nn::{Initializer, PaddingConfig1d};
-use crate::tensor::backend::Backend;
-use crate::tensor::module::conv1d;
-use crate::tensor::ops::ConvOptions;
-use crate::tensor::Tensor;
+use crate::{
+    config::Config,
+    module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param},
+    nn::{conv::checks, Initializer, PaddingConfig1d},
+    tensor::{backend::Backend, module::conv1d, ops::ConvOptions, Tensor},
+};

 /// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
 #[derive(Config, Debug)]
@@ -45,6 +44,7 @@ pub struct Conv1dConfig {
 ///
 /// Should be created with [Conv1dConfig].
 #[derive(Module, Debug)]
+#[module(custom_display)]
 pub struct Conv1d<B: Backend> {
     /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
     pub weight: Param<Tensor<B, 3>>,
@@ -54,7 +54,28 @@ pub struct Conv1d<B: Backend> {
     kernel_size: usize,
     dilation: usize,
     groups: usize,
-    padding: PaddingConfig1d,
+    padding: Ignored<PaddingConfig1d>,
 }

+impl<B: Backend> ModuleDisplay for Conv1d<B> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        // Since padding does not implement ModuleDisplay, we need to format it manually.
+        let padding_formatted = format!("{}", &self.padding);
+
+        content
+            .add("stride", &self.stride)
+            .add("kernel_size", &self.kernel_size)
+            .add("dilation", &self.dilation)
+            .add("groups", &self.groups)
+            .add("padding", &padding_formatted)
+            .optional()
+    }
+}
+
 impl Conv1dConfig {
@@ -87,7 +108,7 @@ impl Conv1dConfig {
             bias,
             stride: self.stride,
             kernel_size: self.kernel_size,
-            padding: self.padding.clone(),
+            padding: Ignored(self.padding.clone()),
             dilation: self.dilation,
             groups: self.groups,
         }
@@ -1,8 +1,9 @@
+use alloc::format;
+
 use crate as burn;

 use crate::config::Config;
-use crate::module::Module;
-use crate::module::Param;
+use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
 use crate::nn::Initializer;
 use crate::nn::PaddingConfig2d;
 use crate::tensor::backend::Backend;
@@ -45,6 +46,7 @@ pub struct Conv2dConfig {
 ///
 /// Should be created with [Conv2dConfig].
 #[derive(Module, Debug)]
+#[module(custom_display)]
 pub struct Conv2d<B: Backend> {
     /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
     pub weight: Param<Tensor<B, 4>>,
@@ -54,7 +56,7 @@ pub struct Conv2d<B: Backend> {
     kernel_size: [usize; 2],
     dilation: [usize; 2],
     groups: usize,
-    padding: PaddingConfig2d,
+    padding: Ignored<PaddingConfig2d>,
 }

 impl Conv2dConfig {
@@ -93,12 +95,38 @@ impl Conv2dConfig {
             stride: self.stride,
             kernel_size: self.kernel_size,
             dilation: self.dilation,
-            padding: self.padding.clone(),
+            padding: Ignored(self.padding.clone()),
             groups: self.groups,
         }
     }
 }

+impl<B: Backend> ModuleDisplay for Conv2d<B> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        // Since padding does not implement ModuleDisplay, we need to format it manually.
+        let padding_formatted = format!("{}", &self.padding);
+
+        // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
+        let stride = format!("{:?}", self.stride);
+        let kernel_size = format!("{:?}", self.kernel_size);
+        let dilation = format!("{:?}", self.dilation);
+
+        content
+            .add("stride", &stride)
+            .add("kernel_size", &kernel_size)
+            .add("dilation", &dilation)
+            .add("groups", &self.groups)
+            .add("padding", &padding_formatted)
+            .optional()
+    }
+}
+
 impl<B: Backend> Conv2d<B> {
     /// Applies the forward pass on the input tensor.
     ///
@@ -1,7 +1,7 @@
 use crate as burn;

 use crate::config::Config;
-use crate::module::Module;
+use crate::module::{DisplaySettings, Module, ModuleDisplay};
 use crate::tensor::backend::Backend;
 use crate::tensor::{Distribution, Tensor};

@@ -21,6 +21,7 @@ pub struct DropoutConfig {
 ///
 /// Should be created with [DropoutConfig].
 #[derive(Module, Clone, Debug)]
+#[module(custom_display)]
 pub struct Dropout {
     prob: f64,
 }
@@ -54,6 +55,18 @@ impl Dropout {
     }
 }

+impl ModuleDisplay for Dropout {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
+        content.add("prob", &self.prob).optional()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1,4 +1,6 @@
 use crate as burn;
+use crate::module::DisplaySettings;
+use crate::module::ModuleDisplay;

 use crate::config::Config;
 use crate::module::Module;
@@ -30,6 +32,7 @@ pub struct LinearConfig {
 ///
 /// `O = IW + b`
 #[derive(Module, Debug)]
+#[module(custom_display)]
 pub struct Linear<B: Backend> {
     /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
     /// `U(-k, k)`, where `k = sqrt(1 / d_input)`
@@ -83,6 +86,23 @@ impl<B: Backend> Linear<B> {
     }
 }

+impl<B: Backend> ModuleDisplay for Linear<B> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
+        let [d_input, d_output] = self.weight.shape().dims;
+        content
+            .add("d_input", &d_input)
+            .add("d_output", &d_output)
+            .add("bias", &self.bias.is_some())
+            .optional()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1,4 +1,5 @@
 use crate as burn;
+use crate::module::{Content, DisplaySettings, ModuleDisplay};

 use crate::nn::Initializer;
 use crate::{
@@ -33,6 +34,7 @@ pub struct BatchNormConfig {
 ///
 /// Should be created using [BatchNormConfig].
 #[derive(Module, Debug)]
+#[module(custom_display)]
 pub struct BatchNorm<B: Backend, const D: usize> {
     /// The learnable weight gamma.
     pub gamma: Param<Tensor<B, 1>>,
@@ -183,6 +185,24 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
     }
 }

+impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        let [num_features] = self.beta.shape().dims;
+
+        content
+            .add("num_features", &num_features)
+            .add("momentum", &self.momentum)
+            .add("epsilon", &self.epsilon)
+            .optional()
+    }
+}
+
 #[cfg(feature = "std")]
 #[cfg(test)]
 mod tests_1d {
@@ -1,7 +1,8 @@
use crate as burn;

use crate::config::Config;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;

@@ -29,6 +30,7 @@ pub struct LayerNormConfig {
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// The learnable weight.
    gamma: Param<Tensor<B, 1>>,

@@ -71,6 +73,22 @@ impl<B: Backend> LayerNorm<B> {
    }
}

impl<B: Backend> ModuleDisplay for LayerNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
        let [d_model] = self.gamma.shape().dims;
        content
            .add("d_model", &d_model)
            .add("epsilon", &self.epsilon)
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

@@ -3,10 +3,9 @@ use crate as burn;
use crate::tensor::ops::conv::calculate_conv_padding;

use crate::config::Config;
use crate::module::Module;

/// Padding configuration for 1D operators.
#[derive(Module, Config, Debug, PartialEq)]
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
    /// the same as the input.

@@ -34,7 +33,7 @@ impl PaddingConfig1d {
}

/// Padding configuration for 2D operators.
#[derive(Module, Config, Debug, PartialEq)]
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
    /// the same as the input.

@@ -1,7 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

@@ -43,7 +43,7 @@ pub struct AvgPool1dConfig {
pub struct AvgPool1d {
    stride: usize,
    kernel_size: usize,
    padding: PaddingConfig1d,
    padding: Ignored<PaddingConfig1d>,
    count_include_pad: bool,
}

@@ -53,7 +53,7 @@ impl AvgPool1dConfig {
        AvgPool1d {
            stride: self.stride,
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            padding: Ignored(self.padding.clone()),
            count_include_pad: self.count_include_pad,
        }
    }

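AvgPool1d above, and the other pooling modules below, stop treating the padding field as module state and wrap it in Ignored instead, keeping non-parameter configuration out of the module record. A simplified model of the wrapper (not burn's exact definition):

// Sketch of the Ignored pattern: carries plain configuration through a
// Module without treating it as a parameter, so it is skipped during
// visits/mapping and (de)serialization.
#[derive(Clone, Debug)]
struct Ignored<T>(pub T);

impl<T> core::ops::Deref for Ignored<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.0
    }
}
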
@@ -1,7 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

@@ -42,7 +42,7 @@ pub struct AvgPool2dConfig {
pub struct AvgPool2d {
    stride: [usize; 2],
    kernel_size: [usize; 2],
    padding: PaddingConfig2d,
    padding: Ignored<PaddingConfig2d>,
    count_include_pad: bool,
}

@@ -52,7 +52,7 @@ impl AvgPool2dConfig {
        AvgPool2d {
            stride: self.strides,
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            padding: Ignored(self.padding.clone()),
            count_include_pad: self.count_include_pad,
        }
    }

@@ -1,7 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

@@ -31,7 +31,7 @@ pub struct MaxPool1dConfig {
pub struct MaxPool1d {
    stride: usize,
    kernel_size: usize,
    padding: PaddingConfig1d,
    padding: Ignored<PaddingConfig1d>,
    dilation: usize,
}

@@ -41,7 +41,7 @@ impl MaxPool1dConfig {
        MaxPool1d {
            stride: self.stride,
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            padding: Ignored(self.padding.clone()),
            dilation: self.dilation,
        }
    }

@@ -1,7 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

@@ -31,7 +31,7 @@ pub struct MaxPool2dConfig {
pub struct MaxPool2d {
    stride: [usize; 2],
    kernel_size: [usize; 2],
    padding: PaddingConfig2d,
    padding: Ignored<PaddingConfig2d>,
    dilation: [usize; 2],
}

@@ -41,7 +41,7 @@ impl MaxPool2dConfig {
        MaxPool2d {
            stride: self.strides,
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            padding: Ignored(self.padding.clone()),
            dilation: self.dilation,
        }
    }

@@ -1,12 +1,11 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::ops::UnfoldOptions;
use crate::tensor::Tensor;

use crate::tensor::module::unfold4d;
use crate::module::{Ignored, Module};
use burn_tensor::backend::Backend;
use burn_tensor::module::unfold4d;
use burn_tensor::ops::UnfoldOptions;
use burn_tensor::Tensor;

/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
#[derive(Config, Debug)]

@@ -29,14 +28,14 @@ pub struct Unfold4dConfig {
/// Should be created with [Unfold4dConfig].
#[derive(Module, Clone, Debug)]
pub struct Unfold4d {
    config: Unfold4dConfig,
    config: Ignored<Unfold4dConfig>,
}

impl Unfold4dConfig {
    /// Initializes a new [Unfold4d] module.
    pub fn init(&self) -> Unfold4d {
        Unfold4d {
            config: self.clone(),
            config: Ignored(self.clone()),
        }
    }
}

@@ -48,7 +47,7 @@ impl Unfold4d {
    ///
    /// # Shapes
    ///
    /// input: `[batch_size, channels_in, height, width]`
    /// input: `[batch_size, channels_in, height, width]`
    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
        unfold4d(

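Shape arithmetic for the documented return value (example numbers are mine, not from the diff):

// channels_in = 3, 2x2 kernel, stride 1, no padding/dilation, 4x4 input:
let (channels_in, k1, k2) = (3usize, 2usize, 2usize);
let (height, width) = (4usize, 4usize);
let blocks = (height - k1 + 1) * (width - k2 + 1); // 3 * 3 = 9 sliding positions
assert_eq!(channels_in * k1 * k2, 12); // second output dimension
assert_eq!(blocks, 9); // third output dimension
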
@@ -370,6 +370,6 @@ mod tests {

        // Compare the lengths of expected and actual serialized strings because
        // the order of the fields is not guaranteed for HashMaps.
        assert_eq!(serialized_str.len(), 134);
        assert_eq!(serialized_str.len(), 108);
    }
}

@@ -49,6 +49,12 @@ pub(crate) fn codegen_block(
pub(crate) struct Codegen {
    pub tokens: proc_macro2::TokenStream,
    pub is_comptime: bool,
    pub array_indexing: Option<ArrayIndexing>,
}

pub(crate) struct ArrayIndexing {
    pub array: proc_macro2::TokenStream,
    pub index: proc_macro2::TokenStream,
}

impl From<proc_macro2::TokenStream> for Codegen {

@@ -56,6 +62,7 @@ impl From<proc_macro2::TokenStream> for Codegen {
        Self {
            tokens,
            is_comptime: false,
            array_indexing: None,
        }
    }
}

@@ -65,6 +72,7 @@ impl Codegen {
        Self {
            tokens: tokens.into(),
            is_comptime,
            array_indexing: None,
        }
    }

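The new field threads index information from expression codegen through to the assignment operators. A reduced model of the bookkeeping (plain strings stand in for the real token streams):

struct ArrayIndexing {
    array: String,
    index: String,
}

struct Codegen {
    tokens: String,
    is_comptime: bool,
    array_indexing: Option<ArrayIndexing>,
}

// Illustrative stand-in for codegen_index: besides emitting the read,
// it remembers what was indexed so `a[i] += v` can be rewritten later.
fn codegen_index_model(array: &str, index: &str) -> Codegen {
    Codegen {
        tokens: format!("index({array}, {index})"),
        is_comptime: false,
        array_indexing: Some(ArrayIndexing {
            array: array.into(),
            index: index.into(),
        }),
    }
}
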
@@ -25,6 +25,7 @@ pub(crate) fn codegen_expr(
        syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker),
        syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker),
        _ => {
            let mut array_indexing = None;
            let tokens = match expr {
                syn::Expr::Path(path) => {
                    return codegen_path_var(path, loop_level, variable_tracker)

@@ -50,7 +51,11 @@ pub(crate) fn codegen_expr(
                syn::Expr::MethodCall(call) => {
                    codegen_expr_method_call(call, loop_level, variable_tracker)
                }
                syn::Expr::Index(index) => codegen_index(index, loop_level, variable_tracker),
                syn::Expr::Index(index) => {
                    let codegen = codegen_index(index, loop_level, variable_tracker);
                    array_indexing = codegen.array_indexing;
                    codegen.tokens
                }
                syn::Expr::Array(array) => codegen_array_lit(array),
                syn::Expr::Reference(reference) => {
                    codegen_ref(reference, loop_level, variable_tracker)

@@ -67,7 +72,9 @@ pub(crate) fn codegen_expr(
                }
            };

            Codegen::new(tokens, false)
            let mut codegen = Codegen::new(tokens, false);
            codegen.array_indexing = array_indexing;
            codegen
        }
    }
}

@@ -8,7 +8,8 @@ pub(crate) fn codegen_binary(
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    let (lhs, is_comptime_lhs) = codegen_expr(&binary.left, loop_level, variable_tracker).split();
    let lhs = codegen_expr(&binary.left, loop_level, variable_tracker);
    let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing);
    let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split();

    if is_comptime_lhs && is_comptime_rhs {

@@ -99,34 +100,94 @@ pub(crate) fn codegen_binary(
                burn_cube::frontend::eq::expand(context, _lhs, _rhs)
            }
        },
        syn::BinOp::AddAssign(_) => quote::quote! {
            {
                let _lhs = #lhs;
                let _rhs = #rhs;
                burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
        syn::BinOp::AddAssign(_) => {
            if let Some(array) = lhs_array {
                let (array, index) = (array.array, array.index);

                quote::quote! {
                    {
                        let _array = #array;
                        let _index = #index;
                        let _value = #rhs;
                        burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value)
                    }
                }
            } else {
                quote::quote! {
                    {
                        let _lhs = #lhs;
                        let _rhs = #rhs;
                        burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
                    }
                }
            }
        },
        syn::BinOp::SubAssign(_) => quote::quote! {
            {
                let _lhs = #lhs;
                let _rhs = #rhs;
                burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
            }
        syn::BinOp::SubAssign(_) => {
            if let Some(array) = lhs_array {
                let (array, index) = (array.array, array.index);

                quote::quote! {
                    {
                        let _array = #array;
                        let _index = #index;
                        let _value = #rhs;
                        burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value)
                    }
                }
            } else {
                quote::quote! {
                    {
                        let _lhs = #lhs;
                        let _rhs = #rhs;
                        burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
                    }
                }
            }
        },
        syn::BinOp::MulAssign(_) => quote::quote! {
            {
                let _lhs = #lhs;
                let _rhs = #rhs;
                burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
            }
        syn::BinOp::MulAssign(_) => {
            if let Some(array) = lhs_array {
                let (array, index) = (array.array, array.index);

                quote::quote! {
                    {
                        let _array = #array;
                        let _index = #index;
                        let _value = #rhs;
                        burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value)
                    }
                }
            } else {
                quote::quote! {
                    {
                        let _lhs = #lhs;
                        let _rhs = #rhs;
                        burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
                    }
                }
            }
        },
        syn::BinOp::DivAssign(_) => quote::quote! {
            {
                let _lhs = #lhs;
                let _rhs = #rhs;
                burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
            }
        syn::BinOp::DivAssign(_) => {
            if let Some(array) = lhs_array {
                let (array, index) = (array.array, array.index);

                quote::quote! {
                    {
                        let _array = #array;
                        let _index = #index;
                        let _value = #rhs;
                        burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value)
                    }
                }
            } else {
                quote::quote! {
                    {
                        let _lhs = #lhs;
                        let _rhs = #rhs;
                        burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
                    }
                }
            }
        },
        }
        syn::BinOp::And(_) => quote::quote! {
            {

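At the source level, these arms are what make compound assignment into an array slot expand correctly inside a #[cube] function; the tests added later in this commit exercise exactly this pattern:

#[cube]
fn array_add_assign(mut array: Array<UInt>) {
    // Lowers to add_assign_array_op::expand (read, add, write back)
    // instead of a plain add_assign_op on a temporary.
    array[UInt::new(1)] += UInt::new(1);
}
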
@@ -99,17 +99,25 @@ pub(crate) fn codegen_index(
    index: &syn::ExprIndex,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
) -> Codegen {
    let array = codegen_expr(&index.expr, loop_level, variable_tracker);
    let index = codegen_expr(&index.index, loop_level, variable_tracker);

    quote::quote! {
    let tokens = quote::quote! {
        {
            let _array = #array;
            let _index = #index;
            burn_cube::frontend::index::expand(context, _array, _index)
        }
    }
    };

    let mut codegen = Codegen::new(tokens, false);
    codegen.array_indexing = Some(super::base::ArrayIndexing {
        array: array.tokens,
        index: index.tokens,
    });

    codegen
}

/// Codegen for assignation

@@ -1,5 +1,5 @@
use crate::ir::{Elem, Item, Visibility};
use crate::prelude::{ArrayExpand, CubeElem, KernelDefinition};
use crate::prelude::KernelDefinition;
use crate::KernelSettings;
use crate::{
    frontend::{CubeContext, ExpandElement},

@@ -60,21 +60,21 @@ impl KernelBuilder {
    }

    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
    pub fn output_array<T: CubeElem>(&mut self, item: Item) -> ArrayExpand<T> {
    pub fn output_array(&mut self, item: Item) -> ExpandElement {
        self.outputs.push(OutputInfo::Array { item });
        let variable = self.context.output_array(self.num_output, item);
        let variable = self.context.output(self.num_output, item);
        self.num_output += 1;

        variable
    }

    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
    pub fn input_array<T: CubeElem>(&mut self, item: Item) -> ArrayExpand<T> {
    pub fn input_array(&mut self, item: Item) -> ExpandElement {
        self.inputs.push(InputInfo::Array {
            item,
            visibility: Visibility::Read,
        });
        let variable = self.context.input_array(self.num_input, item);
        let variable = self.context.input(self.num_input, item);
        self.num_input += 1;
        variable
    }

@@ -4,7 +4,7 @@ use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;

use super::{ArrayExpand, CubeElem, SharedMemoryExpand};
use super::{CubeElem, SharedMemoryExpand};

#[derive(Default, Clone)]
pub struct VariablePool {

@@ -117,10 +117,8 @@ impl CubeContext {
        }
    }

    pub fn create_local_array<T: CubeElem>(&mut self, item: Item, size: u32) -> ArrayExpand<T> {
        ArrayExpand {
            val: ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size)),
        }
    pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
        ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size))
    }

    /// Obtain the index-th input

@@ -128,12 +126,6 @@ impl CubeContext {
        ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
    }

    pub fn input_array<T: CubeElem>(&mut self, index: u16, item: Item) -> ArrayExpand<T> {
        ArrayExpand {
            val: ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item)),
        }
    }

    /// Obtain the index-th output
    pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
        let var = crate::ir::Variable::GlobalOutputArray(index, item);

@@ -141,14 +133,6 @@ impl CubeContext {
        ExpandElement::Plain(var)
    }

    pub fn output_array<T: CubeElem>(&mut self, index: u16, item: Item) -> ArrayExpand<T> {
        let var = crate::ir::Variable::GlobalOutputArray(index, item);
        self.scope.borrow_mut().write_global_custom(var);
        ArrayExpand {
            val: ExpandElement::Plain(var),
        }
    }

    /// Obtain the index-th scalar
    pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
        ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))

@@ -8,44 +8,18 @@ use crate::{
};
use crate::{
    frontend::{indexation::Index, CubeContext, CubeElem},
    ir::Variable,
    prelude::{assign, index, index_assign, Comptime},
};

use super::{ArgSettings, LaunchArg, TensorHandle, UInt};

use super::Init;

#[derive(Clone, Copy)]
pub struct Array<T: CubeType> {
    _val: PhantomData<T>,
}

#[derive(Clone)]
pub struct ArrayExpand<T: CubeElem> {
    pub val: <T as CubeType>::ExpandType,
}

impl<T: CubeElem> From<ArrayExpand<T>> for ExpandElement {
    fn from(array_expand: ArrayExpand<T>) -> Self {
        array_expand.val
    }
}

impl<T: CubeElem> From<ArrayExpand<T>> for Variable {
    fn from(array_expand: ArrayExpand<T>) -> Self {
        *array_expand.val
    }
}

impl<T: CubeElem> Init for ArrayExpand<T> {
    fn init(self, _context: &mut CubeContext) -> Self {
        self
    }
}

impl<T: CubeElem> CubeType for Array<T> {
    type ExpandType = ArrayExpand<T>;
    type ExpandType = ExpandElement;
}

impl<T: CubeElem + Clone> Array<T> {

@@ -90,23 +64,24 @@ impl<T: CubeElem + Clone> Array<T> {
    }
}

impl<T: CubeElem> ArrayExpand<T> {
impl ExpandElement {
    pub fn to_vectorized_expand(
        self,
        context: &mut CubeContext,
        vectorization_factor: UInt,
    ) -> <T as CubeType>::ExpandType {
    ) -> ExpandElement {
        let factor = vectorization_factor.val;
        let var = *self;
        let mut new_var = context.create_local(Item::vectorized(
            T::as_elem(),
            vectorization_factor.val as u8,
            var.item().elem(),
            factor as u8,
        ));
        if vectorization_factor.val == 1 {
            let element = index::expand(context, self.val.clone(), 0);
            let element = index::expand(context, self.clone(), 0u32);
            assign::expand(context, element, new_var.clone());
        } else {
            for i in 0..factor {
                let element = index::expand(context, self.val.clone(), i);
                let element = index::expand(context, self.clone(), i);
                new_var = index_assign::expand(context, new_var, i, element);
            }
        }

@@ -124,11 +99,11 @@ impl<E: CubeType> Array<E> {
impl<C: CubeElem> LaunchArg for Array<C> {
    type RuntimeArg<'a, R: Runtime> = ArrayHandle<'a, R>;

    fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ArrayExpand<C> {
    fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
        builder.input_array(Item::vectorized(C::as_elem(), vectorization))
    }

    fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ArrayExpand<C> {
    fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
        builder.output_array(Item::vectorized(C::as_elem(), vectorization))
    }
}

@@ -96,7 +96,10 @@ impl From<ExpandElement> for Variable {

impl Init for ExpandElement {
    fn init(self, context: &mut CubeContext) -> Self {
        init_expand(context, self, Operator::Assign)
        match *self {
            Variable::LocalArray(_, _, _, _) => self,
            _ => init_expand(context, self, Operator::Assign),
        }
    }
}

@@ -26,6 +26,7 @@ pub trait Float:
    + Erf
    + Recip
    + core::ops::Index<UInt, Output = Self>
    + core::ops::IndexMut<UInt, Output = Self>
{
    fn new(val: f32) -> Self;
    fn new_expand(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;

@@ -109,6 +110,12 @@ macro_rules! impl_float {
            }
        }

        impl core::ops::IndexMut<UInt> for $type {
            fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
                unexpanded!()
            }
        }

        impl LaunchArg for $type {
            type RuntimeArg<'a, R: Runtime> = $primitive;

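The Index/IndexMut bounds exist so that `x[i]` and `x[i] = v` type-check on vectorized float values inside #[cube] code; the bodies are never meant to run on the host. A reduced model of the pattern (Vec4 is a stand-in, not burn-cube's type):

struct Vec4(f32); // stand-in for a vectorized float

impl std::ops::Index<usize> for Vec4 {
    type Output = f32;
    fn index(&self, _i: usize) -> &f32 {
        // Plays the role of unexpanded!(): only meaningful in a generated kernel.
        unimplemented!("only valid inside a generated kernel")
    }
}

impl std::ops::IndexMut<usize> for Vec4 {
    fn index_mut(&mut self, _i: usize) -> &mut f32 {
        unimplemented!("only valid inside a generated kernel")
    }
}
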
@@ -1,3 +1,4 @@

use super::{CubeType, ExpandElement, Tensor, UInt};

pub trait Vectorized {

@@ -113,6 +113,90 @@ pub mod index {
    impl_index!(SharedMemory);
}

pub mod add_assign_array_op {
    use crate::prelude::array_assign_binary_op_expand;

    use self::ir::Operator;

    use super::*;

    pub fn expand<
        Array: Into<ExpandElement>,
        Index: Into<ExpandElement>,
        Value: Into<ExpandElement>,
    >(
        context: &mut CubeContext,
        array: Array,
        index: Index,
        value: Value,
    ) {
        array_assign_binary_op_expand(context, array, index, value, Operator::Add);
    }
}

pub mod sub_assign_array_op {
    use crate::prelude::array_assign_binary_op_expand;

    use self::ir::Operator;

    use super::*;

    pub fn expand<
        Array: Into<ExpandElement>,
        Index: Into<ExpandElement>,
        Value: Into<ExpandElement>,
    >(
        context: &mut CubeContext,
        array: Array,
        index: Index,
        value: Value,
    ) {
        array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
    }
}

pub mod mul_assign_array_op {
    use crate::prelude::array_assign_binary_op_expand;

    use self::ir::Operator;

    use super::*;

    pub fn expand<
        Array: Into<ExpandElement>,
        Index: Into<ExpandElement>,
        Value: Into<ExpandElement>,
    >(
        context: &mut CubeContext,
        array: Array,
        index: Index,
        value: Value,
    ) {
        array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
    }
}

pub mod div_assign_array_op {
    use crate::prelude::array_assign_binary_op_expand;

    use self::ir::Operator;

    use super::*;

    pub fn expand<
        Array: Into<ExpandElement>,
        Index: Into<ExpandElement>,
        Value: Into<ExpandElement>,
    >(
        context: &mut CubeContext,
        array: Array,
        index: Index,
        value: Value,
    ) {
        array_assign_binary_op_expand(context, array, index, value, Operator::Div);
    }
}

pub mod add_assign_op {
    use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};

@@ -203,3 +203,43 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization

    output
}

pub fn array_assign_binary_op_expand<
    Array: Into<ExpandElement>,
    Index: Into<ExpandElement>,
    Value: Into<ExpandElement>,
    F: Fn(BinaryOperator) -> Operator,
>(
    context: &mut CubeContext,
    array: Array,
    index: Index,
    value: Value,
    func: F,
) {
    let array: ExpandElement = array.into();
    let index: ExpandElement = index.into();
    let value: ExpandElement = value.into();

    let tmp = context.create_local(array.item());

    let read = Operator::Index(BinaryOperator {
        lhs: *array,
        rhs: *index,
        out: *tmp,
    });
    let calculate = func(BinaryOperator {
        lhs: *tmp,
        rhs: *value,
        out: *tmp,
    });

    let write = Operator::IndexAssign(BinaryOperator {
        lhs: *index,
        rhs: *tmp,
        out: *array,
    });

    context.register(read);
    context.register(calculate);
    context.register(write);
}

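A host-side model of the three registered operations (a sketch; the real operations execute inside the generated kernel):

fn modeled_add_assign(array: &mut [f32], index: usize, value: f32) {
    let tmp = array[index]; // Operator::Index:       tmp = array[index]
    let tmp = tmp + value;  // func(..), e.g. Add:    tmp = tmp + value
    array[index] = tmp;     // Operator::IndexAssign: array[index] = tmp
}
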
@@ -22,11 +22,21 @@ fn array_of_one_to_vectorized_variable<T: Numeric>() -> T {
    array.to_vectorized(Comptime::new(UInt::new(1)))
}

#[cube]
fn array_add_assign_simple(mut array: Array<UInt>) {
    array[UInt::new(1)] += UInt::new(1);
}

#[cube]
fn array_add_assign_expr(mut array: Array<UInt>) {
    array[UInt::new(1) + UInt::new(5)] += UInt::new(1);
}

mod tests {
    use super::*;
    use burn_cube::{
        cpa,
        ir::{Item, Variable},
        ir::{Elem, Item, Variable},
    };

    type ElemType = F32;

@@ -39,6 +49,20 @@ mod tests {
        assert_eq!(
            format!("{:?}", context.into_scope().operations),
            inline_macro_ref_read_write()
        )
    }

    #[test]
    fn array_add_assign() {
        let mut context = CubeContext::root();
        let array = context.input(0, Item::new(Elem::UInt));

        array_add_assign_simple_expand(&mut context, array);
        let scope = context.into_scope();

        assert_eq!(
            format!("{:?}", scope.operations),
            inline_macro_array_add_assign_simple()
        );
    }

@@ -84,6 +108,37 @@ mod tests {
        format!("{:?}", scope.operations)
    }

    #[test]
    fn array_add_assign_expr() {
        let mut context = CubeContext::root();
        let array = context.input(0, Item::new(Elem::UInt));

        array_add_assign_expr_expand(&mut context, array);
        let scope = context.into_scope();

        assert_eq!(
            format!("{:?}", scope.operations),
            inline_macro_array_add_assign_expr()
        );
    }

    fn inline_macro_array_add_assign_simple() -> String {
        let context = CubeContext::root();

        let mut scope = context.into_scope();
        let local = scope.create_local(Item::new(Elem::UInt));

        let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
        let index = Variable::ConstantScalar(1., Elem::UInt);
        let value = Variable::ConstantScalar(1., Elem::UInt);

        cpa!(scope, local = array[index]);
        cpa!(scope, local += value);
        cpa!(scope, array[index] = local);

        format!("{:?}", scope.operations)
    }

    fn inline_macro_ref_to_vectorized() -> String {
        let context = CubeContext::root();
        let scalar_item = Item::new(ElemType::as_elem());

@@ -123,4 +178,24 @@ mod tests {

        format!("{:?}", scope.operations)
    }

    fn inline_macro_array_add_assign_expr() -> String {
        let context = CubeContext::root();

        let mut scope = context.into_scope();
        let index = scope.create_local(Item::new(Elem::UInt));
        let local = scope.create_local(Item::new(Elem::UInt));

        let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
        let const1 = Variable::ConstantScalar(1., Elem::UInt);
        let const2 = Variable::ConstantScalar(5., Elem::UInt);
        let value = Variable::ConstantScalar(1., Elem::UInt);

        cpa!(scope, index = const1 + const2);
        cpa!(scope, local = array[index]);
        cpa!(scope, local += value);
        cpa!(scope, array[index] = local);

        format!("{:?}", scope.operations)
    }
}

@@ -13,7 +13,7 @@ pub(crate) mod record;
pub(crate) mod shared;

/// Derive macro for the module.
#[proc_macro_derive(Module)]
#[proc_macro_derive(Module, attributes(module))]
pub fn module_derive(input: TokenStream) -> TokenStream {
    let input = syn::parse(input).unwrap();
    module::derive_impl(&input)

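Registering the `module` helper attribute is what allows the opt-in used by the nn modules earlier in this commit:

#[derive(Module, Debug)]
#[module(custom_display)] // accepted only because of attributes(module) above
pub struct MyModule<B: Backend> {
    // MyModule and its field are illustrative, not from the diff.
    linear: Linear<B>,
}
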
@@ -2,7 +2,7 @@ use super::{display, record::ModuleRecordCodegen};
use crate::shared::generics::GenericsHelper;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use syn::{parse_quote, Attribute, Generics};

/// Basic trait to be implemented for Module generation.
pub(crate) trait ModuleCodegen {

@@ -30,8 +30,8 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(

    let generics = GenericsParser::from_ast(&ast.generics);

    let display_fn = display::display_fn(name);

    let display_fn = display::display_fn(ast);
    let attributes_fn = display::attributes_fn(ast);
    let num_params_fn = codegen.gen_num_params();
    let visit = codegen.gen_visit();
    let map_mut = codegen.gen_map();

@@ -54,7 +54,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(

    let generics_ty_inner_module = generics.inner_module_ty;

    let gen = quote! {
    let mut gen = quote! {
        impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
            type Record = #record_name #generics_ty_module;

@@ -69,6 +69,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
            #collect_devices
            #to_device
            #fork

        }

        impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff

@@ -82,6 +83,15 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
            #display_fn
        }

        impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {
            #attributes_fn

            fn num_params(&self) -> usize {
                burn::module::Module::num_params(self)
            }
        }

        impl #generics_module Clone for #name #generics_ty_module #generics_where_module {
            #clone_fn
        }

@@ -89,13 +99,21 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
        #record_type
    };

    if !has_custom_display(&ast.attrs) {
        gen.extend(quote! {
            impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {

            }
        });
    }

    gen
}

// When there is no backend in the generic parameter, the type is considered as a constant.
pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let (_generics, generics_ty, generics_where) = ast.generics.split_for_impl();
    let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();

    let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
    let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};

@@ -112,7 +130,10 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
    let (generics_module, _, _) = generics_module.split_for_impl();
    let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();

    let gen = quote! {
    let display_fn = display::display_fn(ast);
    let attributes_fn = display::attributes_fn(ast);

    let mut gen = quote! {
        impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
            burn::constant!(module);
        }

@@ -121,8 +142,26 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
            for #name #generics_ty #generics_where {
            burn::constant!(ad_module, #name #generics_ty);
        }

        impl #generics core::fmt::Display for #name #generics_ty #generics_where {
            #display_fn
        }

        impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {
            #attributes_fn
        }

    };

    if !has_custom_display(&ast.attrs) {
        gen.extend(quote! {
            impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {

            }
        });
    }

    gen
}

@@ -159,22 +198,64 @@ impl GenericsParser {
                    #ident: burn::module::Module<B>
                }
            );

            module.add_predicate(
                parse_quote! {
                    #ident: burn::module::ModuleDisplayDefault
                }
            );

            module.add_predicate(
                parse_quote! {
                    #ident: burn::module::ModuleDisplay
                }
            );

            module_autodiff.add_predicate(
                parse_quote! {
                    #ident: burn::module::AutodiffModule<B>
                }
            );
            module_autodiff.add_predicate(

            module_autodiff.add_predicate(
                parse_quote! {
                    <#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>
                }
            );

            module_autodiff.add_predicate(
                parse_quote! {
                    <#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
                }
            );

            module_autodiff.add_predicate(
                parse_quote! {
                    <#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
                }
            );

            generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });
            module_autodiff.add_predicate(
                parse_quote! {
                    #ident: burn::module::Module<B::InnerBackend>
                }
            );

            module_autodiff.add_predicate(
                parse_quote! {
                    #ident: burn::module::ModuleDisplayDefault
                }
            );

            module_autodiff.add_predicate(
                parse_quote! {
                    #ident: burn::module::ModuleDisplay
                }
            );

        });

        module.consts().into_iter().for_each(|ident| {

@@ -188,3 +269,18 @@ impl GenericsParser {
        }
    }
}

fn has_custom_display(attrs: &[Attribute]) -> bool {
    attrs.iter().any(|attr| {
        attr.path().is_ident("module")
            && attr
                .parse_nested_meta(|meta| {
                    if meta.path.is_ident("custom_display") {
                        Ok(())
                    } else {
                        Err(meta.error("unsupported attribute"))
                    }
                })
                .is_ok()
    })
}

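The net effect of has_custom_display, sketched: without the attribute, the derive appends an empty marker impl so every module gets the default display; with #[module(custom_display)], emission is skipped and a hand-written impl like those in the nn modules above applies.

// What the derive emits for a module without #[module(custom_display)]
// (MyModule is hypothetical):
impl burn::module::ModuleDisplay for MyModule {}
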
@@ -1,11 +1,96 @@
use proc_macro2::Ident;
use quote::quote;

pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream {
pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
    match &ast.data {
        syn::Data::Struct(ref data_struct) => {
            let fields = match &data_struct.fields {
                syn::Fields::Named(ref named_fields) => {
                    named_fields.named.iter().collect::<Vec<_>>()
                }
                syn::Fields::Unit => Vec::new(),
                _ => panic!("attributes_fn only supports structs with named or unit fields"),
            };
            let field_prints = fields.iter().map(|field| {
                let field_name = &field.ident;
                quote! { .add(stringify!(#field_name), &self.#field_name) }
            });
            let struct_name = &ast.ident;
            quote! {
                fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
                    content
                        .set_top_level_type(&stringify!(#struct_name))
                        #(#field_prints)*
                        .optional()
                }
            }
        }
        syn::Data::Enum(ref data_enum) => {
            let variant_prints = data_enum.variants.iter().map(|variant| {
                let variant_name = &variant.ident;
                match &variant.fields {
                    syn::Fields::Unit => {
                        quote! {
                            Self::#variant_name => {
                                content.add_formatted(&stringify!(#variant_name).to_string())
                                    .optional()

                            }
                        }
                    }
                    syn::Fields::Named(ref named_fields) => {
                        let field_prints = named_fields.named.iter().map(|field| {
                            let field_name = &field.ident;
                            quote! { .add(stringify!(#field_name), #field_name) }
                        });

                        let field_names = named_fields.named.iter().map(|field| {
                            let field_name = &field.ident;
                            quote! { #field_name }
                        });

                        quote! {
                            Self::#variant_name { #(#field_names),* } => {
                                content.set_top_level_type(&stringify!(#variant_name))
                                    #(#field_prints)*
                                    .optional()
                            }
                        }
                    }
                    syn::Fields::Unnamed(ref unnamed_fields) => {
                        let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {
                            syn::Ident::new(&format!("_{}", i), proc_macro2::Span::call_site())
                        });

                        let field_prints = field_names.clone().map(|field_name| {
                            quote! { .add(stringify!(#field_name), #field_name) }
                        });
                        quote! {
                            Self::#variant_name(#(#field_names),*) => {
                                content.set_top_level_type(&stringify!(#variant_name))
                                    #(#field_prints)*
                                    .optional()
                            }
                        }
                    }
                }
            });
            quote! {
                fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
                    match self {
                        #(#variant_prints)*
                    }
                }
            }
        }
        _ => panic!("attributes_fn only supports structs and enums"),
    }
}

pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
    quote! {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(f, "{}[num_params={}]", stringify!(#name), self.num_params())

            let formatted = burn::module::ModuleDisplay::format(self, Default::default());
            write!(f, "{}", formatted)
        }
    }
}

@@ -148,23 +148,25 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(

    let client = lhs.client.clone();

    let lhs = match lhs.batch_swapped_with_row_col() {
        true => into_contiguous(lhs),
        false => lhs,
    };
    let rhs = match rhs.batch_swapped_with_row_col() {
        true => into_contiguous(rhs),
        false => rhs,
    };
    let lhs = into_contiguous(lhs);
    let rhs = into_contiguous(rhs);
    // let lhs = match lhs.batch_swapped_with_row_col() {
    //     true => into_contiguous(lhs),
    //     false => lhs,
    // };
    // let rhs = match rhs.batch_swapped_with_row_col() {
    //     true => into_contiguous(rhs),
    //     false => rhs,
    // };

    config.block_size_m = 16;
    config.block_size_n = 16;
    config.block_size_k = 16; // k must be <= both m and n
    config.block_size_m = 64;
    config.block_size_n = 64;
    config.block_size_k = 32; // k must be <= both m and n
    let cube_count = tiling2d_launch_options(&out.shape, config.clone());

    let vectorization_factor = 4;
    let x = (config.block_size_m / vectorization_factor) as u32;
    let y = (config.block_size_n / vectorization_factor) as u32;
    let x = (config.block_size_m / 4) as u32;
    let y = (config.block_size_n / 4) as u32;
    let settings = KernelSettings::default()
        .vectorize_input(0, vectorization_factor as u8)
        .vectorize_input(1, vectorization_factor as u8)

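A sanity check of the launch arithmetic with the new block sizes (plain numbers mirroring the lines above):

let (block_size_m, block_size_n, vectorization_factor) = (64u32, 64u32, 4u32);
assert_eq!(block_size_m / vectorization_factor, 16); // x cube dimension
assert_eq!(block_size_n / vectorization_factor, 16); // y cube dimension
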
@@ -9,11 +9,12 @@ use super::{
};

#[cube]
#[allow(unused_mut)]
pub(crate) fn compute_loop<F: Float>(
    coordinates: Coordinates,
    shared_lhs: SharedMemory<F>,
    shared_rhs: SharedMemory<F>,
    results: Array<F>,
    mut results: Array<F>,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);

@@ -25,14 +26,13 @@ pub(crate) fn compute_loop<F: Float>(
    let unit_row = coordinates.unit_row;
    let unit_col = coordinates.unit_col;

    let lhs_stride = block_size_m / tile_size;
    let rhs_stride = block_size_n / tile_size;

    for dot_index in range(0u32, block_size_k, unroll) {
        let register_m = shared_lhs[(unit_col + dot_index) * Comptime::runtime(lhs_stride)];
        let register_n = shared_rhs[(unit_row + dot_index) * Comptime::runtime(rhs_stride)];
        let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
            / Comptime::runtime(tile_size)];
        let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
            / Comptime::runtime(tile_size)];

        tile_outer_product(register_m, register_n, results, config);
        tile_outer_product::<F>(register_m, register_n, results, config);
    }
}

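A numeric probe of the new shared-memory indexing (example values are mine):

// tile_size = 4, block_size_m = 64, unit_row = 8, dot_index = 2
let (tile_size, block_size_m) = (4u32, 64u32);
let (unit_row, dot_index) = (8u32, 2u32);
let index = (unit_row + dot_index * block_size_m) / tile_size; // (8 + 128) / 4
assert_eq!(index, 34);
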
@@ -51,16 +51,19 @@ pub mod tests {
        config: Comptime<CubeTiling2dConfig>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let block_size_m = Comptime::map(config, |c| c.block_size_m);
        let block_size_k = Comptime::map(config, |c| c.block_size_k);
        let block_size_n = Comptime::map(config, |c| c.block_size_n);
        let sm_size_lhs = block_size_m * block_size_k / tile_size;
        let sm_size_rhs = block_size_n * block_size_k / tile_size;

        // Shared memories are not launchable, so we launch with tensor and convert to shared memory
        let sm_size_lhs = Comptime::map(config, |c| c.sm_size_lhs);
        let mut shared_lhs =
            SharedMemory::<F>::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size));
        for i in range(0u32, lhs.len(), Comptime::new(false)) {
            shared_lhs[i] = lhs[i];
        }

        let sm_size_rhs = Comptime::map(config, |c| c.sm_size_rhs);
        let mut shared_rhs =
            SharedMemory::<F>::vectorized(Comptime::get(sm_size_rhs), Comptime::get(tile_size));
        for i in range(0u32, rhs.len(), Comptime::new(false)) {

@@ -135,4 +138,62 @@ pub mod tests {
        ];
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn compute_loop_unit_offset_test<R: JitRuntime>(device: &R::Device) {
        pub type B<R> = JitBackend<R, f32, i32>;

        let tile_size = 4;
        let lhs = burn_tensor::Tensor::<B<R>, 2>::from_data(
            burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..32, device)
                .reshape([8, 4])
                .float()
                .transpose()
                .into_data(),
            device,
        )
        .into_primitive();
        let rhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..32, device)
            .reshape([4, 8])
            .float()
            .into_primitive();
        let client = R::client(device);

        let unit_row = 4;
        let unit_col = 4;
        let results = client.empty(tile_size * tile_size * core::mem::size_of::<f32>());

        // Unit test
        let cube_count = CubeCount::new(1, 1, 1);
        let settings = KernelSettings::default()
            .cube_dim(CubeDim::new(1, 1, 1))
            .vectorize_input(0, tile_size as u8)
            .vectorize_input(1, tile_size as u8);

        let mut tiling2d_config = Tiling2dConfig::default();
        tiling2d_config.block_size_m = 8;
        tiling2d_config.block_size_k = 8;
        tiling2d_config.block_size_n = 8;
        let config = CubeTiling2dConfig::new(tiling2d_config, 4, 8, 4, tile_size);

        compute_loop_test_launch::<F32, R>(
            client.clone(),
            cube_count,
            settings,
            TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
            TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
            unit_row,
            unit_col,
            ArrayHandle::new(&results, 1),
            config,
        );

        let actual = client.read(results.binding()).read_sync().unwrap();
        let actual = f32::from_bytes(&actual);
        let expected = &[
            1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0,
            1978.0, 1928.0, 2046.0, 2164.0, 2282.0,
        ];
        assert_eq!(actual, expected);
    }
}

@@ -25,19 +25,12 @@ pub struct CubeTiling2dConfig {
    pub check_k_bounds: bool,
    /// Bounds must be checked on rhs dimension
    pub check_n_bounds: bool,
    /// Shared memory size lhs: technically derivable from others, but needs comptime arithmetic
    pub sm_size_lhs: UInt,
    /// Shared memory size rhs: technically derivable from others, but needs comptime arithmetic
    pub sm_size_rhs: UInt,
    /// Tile size. Should correspond to vectorization of inputs/outputs/shared memory
    pub tile_size: UInt,
}

impl CubeTiling2dConfig {
    pub fn new(config: Tiling2dConfig, m: usize, k: usize, n: usize, tile_size: usize) -> Self {
        let sm_size_lhs = (config.block_size_m / tile_size) * config.block_size_k;
        let sm_size_rhs = (config.block_size_n / tile_size) * config.block_size_k;

        CubeTiling2dConfig {
            block_size_m: UInt::new(config.block_size_m as u32),
            block_size_k: UInt::new(config.block_size_k as u32),

@@ -46,8 +39,6 @@ impl CubeTiling2dConfig {
            check_m_bounds: m % config.block_size_m != 0,
            check_k_bounds: k % config.block_size_k != 0,
            check_n_bounds: n % config.block_size_n != 0,
            sm_size_lhs: UInt::new(sm_size_lhs as u32),
            sm_size_rhs: UInt::new(sm_size_rhs as u32),
            tile_size: UInt::new(tile_size as u32),
        }
    }

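The removed fields are now recomputed at the use sites; with the block sizes this commit switches to (m = n = 64, k = 32, tile_size = 4) the sizes come out as:

let (block_size_m, block_size_k, block_size_n, tile_size) = (64usize, 32, 64, 4);
let sm_size_lhs = (block_size_m / tile_size) * block_size_k; // 16 * 32 = 512 vectors
let sm_size_rhs = (block_size_n / tile_size) * block_size_k; // 512 vectors
assert_eq!((sm_size_lhs, sm_size_rhs), (512, 512));
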
@@ -33,12 +33,12 @@ pub(crate) fn load_lhs_transposed<F: Float>(

    let tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));

    load_tile(
    load_tile::<F>(
        lhs,
        tile,
        offset,
        unit_col,
        unit_row,
        unit_col,
        skip_row,
        skip_col,
        Comptime::map(config, |c| c.check_m_bounds),

@@ -46,7 +46,7 @@ pub(crate) fn load_lhs_transposed<F: Float>(
        config,
    );

    write_tile_transposed(
    write_tile_transposed::<F>(
        tile,
        shared_lhs,
        sm_position_base,

@@ -83,7 +83,7 @@ pub(crate) fn load_rhs_plain<F: Float>(

    let tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));

    load_tile(
    load_tile::<F>(
        rhs,
        tile,
        offset,

@@ -96,7 +96,7 @@ pub(crate) fn load_rhs_plain<F: Float>(
        config,
    );

    write_tile_plain(
    write_tile_plain::<F>(
        tile,
        shared_rhs,
        sm_position_base,

@@ -128,16 +128,16 @@ fn load_tile<F: Float>(

    if Comptime::get(check_vertical_bounds) {
        let row = skip_row + load_row;
        let dim_vertical = tensor.shape(tensor.rank() - UInt::new(2));
        let dim_vertical = tensor.shape(rank - UInt::new(2));

        if Comptime::get(check_horizontal_bounds) {
            let col = skip_col + load_col;
            let dim_horizontal = tensor.shape(tensor.rank() - UInt::new(1));
            let dim_horizontal = tensor.shape(rank - UInt::new(1));

            if col >= dim_horizontal {
                read_zeros(tile, tile_size, unroll);
                read_zeros::<F>(tile, tile_size, unroll);
            } else {
                read_partial(
                read_partial::<F>(
                    tensor,
                    dim_vertical,
                    row,

@@ -148,7 +148,7 @@ fn load_tile<F: Float>(
                );
            }
        } else {
            read_partial(
            read_partial::<F>(
                tensor,
                dim_vertical,
                row,

@@ -161,11 +161,11 @@ fn load_tile<F: Float>(
    } else {
        if Comptime::get(check_horizontal_bounds) {
            let col = skip_col + load_col;
            let dim_horizontal = tensor.shape(tensor.rank() - UInt::new(1));
            let dim_horizontal = tensor.shape(rank - UInt::new(1));
            if col >= dim_horizontal {
                read_zeros(tile, tile_size, unroll);
                read_zeros::<F>(tile, tile_size, unroll);
            } else {
                read_whole(
                read_whole::<F>(
                    tensor,
                    tensor_position_base,
                    tensor_stride,

@@ -175,7 +175,7 @@ fn load_tile<F: Float>(
                );
            }
        } else {
            read_whole(
            read_whole::<F>(
                tensor,
                tensor_position_base,
                tensor_stride,

@@ -221,12 +221,16 @@ fn write_tile_transposed<F: Float>(
        shared_memory[sm_position_base] = tile[0];
    } else {
        for i in range(0u32, Comptime::get(tile_size), unroll) {
            let mut transposed = Array::<F>::new(Comptime::get(tile_size));
            let mut transposed = F::vectorized(0., Comptime::get(tile_size));
            for j in range(0u32, Comptime::get(tile_size), unroll) {
                transposed[j] = tile[j][i];
            }
            let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization;
            shared_memory[sm_position] = transposed.to_vectorized(tile_size);
            shared_memory[sm_position] = transposed;

            // let mut x = F::vectorized(0., Comptime::get(tile_size));
            // x[UInt::new(0)] = F::cast_from(UNIT_POS);
            // shared_memory[UNIT_POS] = x;
        }
    }
}

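A scalar model of the transposed write above: gather column i of the tile into one vectorized value, then store it as a single shared-memory slot (simplified to plain arrays, 4 standing in for tile_size):

fn transpose_column(tile: &[[f32; 4]; 4], i: usize) -> [f32; 4] {
    let mut transposed = [0.0f32; 4]; // stands in for F::vectorized(0., tile_size)
    for j in 0..4 {
        transposed[j] = tile[j][i];
    }
    transposed // written back as one slot: shared_memory[sm_position] = transposed
}
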
@@ -286,7 +290,7 @@ pub mod tests {
    #[cube(launch)]
    #[allow(unused_mut)]
    fn read_whole_test<F: Float>(tensor: Tensor<F>, mut tile: Array<F>, tile_size: Comptime<UInt>) {
        read_whole(
        read_whole::<F>(
            tensor,
            UInt::new(0),
            tensor.stride(0),

@@ -303,7 +307,7 @@ pub mod tests {
        mut tile: Array<F>,
        tile_size: Comptime<UInt>,
    ) {
        read_partial(
        read_partial::<F>(
            tensor,
            Comptime::runtime(tile_size),
            UInt::new(2),

@@ -317,7 +321,7 @@ pub mod tests {
    #[cube(launch)]
    #[allow(unused_mut)]
    fn read_zeros_test<F: Float>(mut tile: Array<F>, tile_size: Comptime<UInt>) {
        read_zeros(tile, tile_size, Comptime::new(true))
        read_zeros::<F>(tile, tile_size, Comptime::new(true))
    }

    #[cube(launch)]

@@ -361,7 +365,7 @@ pub mod tests {

        let sm_stride = block_size_m;
        let sm_size = Comptime::runtime(block_size_k * block_size_m);
        let shared_memory = SharedMemory::vectorized(sm_size, Comptime::get(tile_size));
        let shared_memory = SharedMemory::<F>::vectorized(sm_size, Comptime::get(tile_size));

        if Comptime::get(transposed) {
            write_tile_transposed(
@@ -424,6 +428,42 @@ pub mod tests {
        }
    }

    #[cube(launch)]
    fn load_tensor_multiple_tiles_test<F: Float>(
        tensor: Tensor<F>,
        mut sm_out: Array<F>,
        k: UInt,
        config: Comptime<CubeTiling2dConfig>,
        is_lhs: Comptime<bool>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let block_size_k = Comptime::map(config, |c| c.block_size_k);
        let block_size_m = Comptime::map(config, |c| c.block_size_m);
        let sm_size = block_size_k * block_size_m / tile_size;
        let shared_memory =
            SharedMemory::<F>::vectorized(Comptime::get(sm_size), Comptime::get(tile_size));

        let unit_row = UInt::new(4) * UNIT_POS_X;
        let unit_col = UInt::new(4) * UNIT_POS_Y;
        let offset = UInt::new(0);

        let coordinates = Coordinates {
            unit_row,
            unit_col,
            skip_row: UInt::new(0),
            skip_col: UInt::new(0),
        };
        if Comptime::get(is_lhs) {
            load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config);
        } else {
            load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config);
        }

        for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) {
            sm_out[i] = shared_memory[i];
        }
    }

    /// Exported test
    pub fn read_whole_unit_test<R: JitRuntime>(device: &R::Device) {
        pub type B<R> = JitBackend<R, f32, i32>;

@ -817,6 +857,112 @@ pub mod tests {
|
|||
assert_eq!(actual, expected);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn load_lhs_transposed_cube_test<R: JitRuntime>(device: &R::Device) {
|
||||
pub type B<R> = JitBackend<R, f32, i32>;
|
||||
|
||||
let tile_size = 4;
|
||||
let lhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..64, device)
|
||||
.reshape([8, 8])
|
||||
.float()
|
||||
.into_primitive();
|
||||
let client = R::client(device);
|
||||
|
||||
// Unit test
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
let settings = KernelSettings::default()
|
||||
.cube_dim(CubeDim::new(2, 2, 1))
|
||||
.vectorize_input(0, tile_size as u8)
|
||||
.vectorize_output(0, tile_size as u8);
|
||||
|
||||
let mut tiling2d_config = Tiling2dConfig::default();
|
||||
tiling2d_config.block_size_m = 8;
|
||||
tiling2d_config.block_size_k = 8;
|
||||
tiling2d_config.block_size_n = 8;
|
||||
let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 8, 8, 8, tile_size);
|
||||
|
||||
let sm_out = client.empty(
|
||||
tiling2d_config.block_size_k
|
||||
* tiling2d_config.block_size_m
|
||||
* core::mem::size_of::<f32>(),
|
||||
);
|
||||
|
||||
load_tensor_multiple_tiles_test_launch::<F32, R>(
|
||||
client.clone(),
|
||||
cube_count,
|
||||
settings,
|
||||
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
ArrayHandle::new(&sm_out, 64),
|
||||
0,
|
||||
config,
|
||||
true,
|
||||
);
|
||||
|
||||
let actual = client.read(sm_out.binding()).read_sync().unwrap();
|
||||
let actual = f32::from_bytes(&actual);
        let expected = &[
            0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0,
            57.0, 2.0, 10.0, 18.0, 26.0, 34.0, 42.0, 50.0, 58.0, 3.0, 11.0, 19.0, 27.0, 35.0, 43.0,
            51.0, 59.0, 4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 5.0, 13.0, 21.0, 29.0, 37.0,
            45.0, 53.0, 61.0, 6.0, 14.0, 22.0, 30.0, 38.0, 46.0, 54.0, 62.0, 7.0, 15.0, 23.0, 31.0,
            39.0, 47.0, 55.0, 63.0,
        ];
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn load_lhs_transposed_offset_cube_test<R: JitRuntime>(device: &R::Device) {
        pub type B<R> = JitBackend<R, f32, i32>;

        let tile_size = 4;
        let lhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..128, device)
            .reshape([8, 16])
            .float()
            .into_primitive();
        let client = R::client(device);

        // Unit test
        let cube_count = CubeCount::new(1, 1, 1);
        let settings = KernelSettings::default()
            .cube_dim(CubeDim::new(2, 2, 1))
            .vectorize_input(0, tile_size as u8)
            .vectorize_output(0, tile_size as u8);

        let mut tiling2d_config = Tiling2dConfig::default();
        tiling2d_config.block_size_m = 8;
        tiling2d_config.block_size_k = 8;
        tiling2d_config.block_size_n = 8;
        let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 8, 8, 16, tile_size);

        let sm_out = client.empty(
            tiling2d_config.block_size_k
                * tiling2d_config.block_size_m
                * core::mem::size_of::<f32>(),
        );

        load_tensor_multiple_tiles_test_launch::<F32, R>(
            client.clone(),
            cube_count,
            settings,
            TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
            ArrayHandle::new(&sm_out, 64),
            8,
            config,
            true,
        );

        let actual = client.read(sm_out.binding()).read_sync().unwrap();
        let actual = f32::from_bytes(&actual);
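        // lhs is 8x16 here and the offset of 8 shifts the load by one block
        // along k: transposed, column 8 of the input (8, 24, ..., 120) becomes
        // the first row of the tile.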
        let expected = &[
            8.0, 24.0, 40.0, 56.0, 72.0, 88.0, 104.0, 120.0, 9.0, 25.0, 41.0, 57.0, 73.0, 89.0,
            105.0, 121.0, 10.0, 26.0, 42.0, 58.0, 74.0, 90.0, 106.0, 122.0, 11.0, 27.0, 43.0, 59.0,
            75.0, 91.0, 107.0, 123.0, 12.0, 28.0, 44.0, 60.0, 76.0, 92.0, 108.0, 124.0, 13.0, 29.0,
            45.0, 61.0, 77.0, 93.0, 109.0, 125.0, 14.0, 30.0, 46.0, 62.0, 78.0, 94.0, 110.0, 126.0,
            15.0, 31.0, 47.0, 63.0, 79.0, 95.0, 111.0, 127.0,
        ];
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn load_rhs_plain_unit_test<R: JitRuntime>(device: &R::Device) {
        pub type B<R> = JitBackend<R, f32, i32>;

@@ -868,4 +1014,110 @@ pub mod tests {
        ];
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn load_rhs_plain_cube_test<R: JitRuntime>(device: &R::Device) {
        pub type B<R> = JitBackend<R, f32, i32>;

        let tile_size = 4;
        let rhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..64, device)
            .reshape([8, 8])
            .float()
            .into_primitive();
        let client = R::client(device);

        // Unit test
        let cube_count = CubeCount::new(1, 1, 1);
        let settings = KernelSettings::default()
            .cube_dim(CubeDim::new(2, 2, 1))
            .vectorize_input(0, tile_size as u8)
            .vectorize_output(0, tile_size as u8);

        let mut tiling2d_config = Tiling2dConfig::default();
        tiling2d_config.block_size_m = 8;
        tiling2d_config.block_size_k = 8;
        tiling2d_config.block_size_n = 8;
        let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 8, 8, 8, tile_size);

        let sm_out = client.empty(
            tiling2d_config.block_size_k
                * tiling2d_config.block_size_m
                * core::mem::size_of::<f32>(),
        );

        load_tensor_multiple_tiles_test_launch::<F32, R>(
            client.clone(),
            cube_count,
            settings,
            TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
            ArrayHandle::new(&sm_out, 64),
            0,
            config,
            false,
        );

        let actual = client.read(sm_out.binding()).read_sync().unwrap();
        let actual = f32::from_bytes(&actual);
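        // rhs is loaded plain (no transposition), so the tile should reproduce
        // the row-major arange input 0..64 unchanged.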
        let expected = &[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
            30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0,
            44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0,
            58.0, 59.0, 60.0, 61.0, 62.0, 63.0,
        ];
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn load_rhs_plain_cube_offset_test<R: JitRuntime>(device: &R::Device) {
        pub type B<R> = JitBackend<R, f32, i32>;

        let tile_size = 4;
        let rhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..128, device)
            .reshape([16, 8])
            .float()
            .into_primitive();
        let client = R::client(device);

        // Unit test
        let cube_count = CubeCount::new(1, 1, 1);
        let settings = KernelSettings::default()
            .cube_dim(CubeDim::new(2, 2, 1))
            .vectorize_input(0, tile_size as u8)
            .vectorize_output(0, tile_size as u8);

        let mut tiling2d_config = Tiling2dConfig::default();
        tiling2d_config.block_size_m = 8;
        tiling2d_config.block_size_k = 8;
        tiling2d_config.block_size_n = 8;
        let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 16, 16, 8, tile_size);

        let sm_out = client.empty(
            tiling2d_config.block_size_k
                * tiling2d_config.block_size_m
                * core::mem::size_of::<f32>(),
        );

        load_tensor_multiple_tiles_test_launch::<F32, R>(
            client.clone(),
            cube_count,
            settings,
            TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
            ArrayHandle::new(&sm_out, 64),
            8,
            config,
            false,
        );

        let actual = client.read(sm_out.binding()).read_sync().unwrap();
        let actual = f32::from_bytes(&actual);
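        // The offset of 8 skips the first 8 rows of the 16x8 rhs, so the tile
        // should hold rows 8..16, i.e. the values 64..128 in row-major order.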
        let expected = &[
            64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0,
            78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0,
            92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0,
            105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0,
            117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0,
        ];
        assert_eq!(actual, expected);
    }
}

@@ -14,18 +14,15 @@ pub(crate) fn tile_outer_product<F: Float>(
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let unroll = Comptime::map(config, |c| c.unroll);
    let is_scalar = Comptime::map(tile_size, |c| c.val == 1);

    if Comptime::get(is_scalar) {
        // works
        results[0] = results[0] + register_m * register_n;
        // doesn't work
        results[0] += register_m * register_n;
    } else {
        for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
            let res_pos_base = res_idx_m * Comptime::runtime(tile_size);
            for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
                let mul = register_m[res_idx_m] * register_n[res_idx_n];
                // results[res_pos_base + res_idx_n] += mul;
                results[res_pos_base + res_idx_n] = results[res_pos_base + res_idx_n] + mul;
                results[res_pos_base + res_idx_n] += mul;
            }
        }
    }

@@ -57,7 +54,7 @@ pub mod tests {
        ) {
            results[i] = F::new(0.);
        }
        tile_outer_product(register_m, register_n, results, config)
        tile_outer_product::<F>(register_m, register_n, results, config)
    }

    fn test_case_config(tile_size: usize) -> CubeTiling2dConfig {

@@ -97,6 +94,38 @@ pub mod tests {
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn tile_outer_product_vectorized_unit_test_2<R: Runtime>(device: &R::Device) {
        let client = R::client(device);

        let register_m = client.create(f32::as_bytes(&[16., 20., 24., 28.]));
        let register_n = client.create(f32::as_bytes(&[4., 5., 6., 7.]));
        let results = client.empty(16 * core::mem::size_of::<f32>());

        // Unit test
        let cube_count = CubeCount::new(1, 1, 1);
        let settings = KernelSettings::default().cube_dim(CubeDim::new(1, 1, 1));
        let config = test_case_config(4);

        tile_outer_product_test_launch::<F32, R>(
            client.clone(),
            cube_count,
            settings,
            ArrayHandle::new(&register_m, 4),
            ArrayHandle::new(&register_n, 4),
            ArrayHandle::new(&results, 16),
            config,
        );

        let actual = client.read(results.binding()).read_sync().unwrap();
        let actual = f32::from_bytes(&actual);
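        // Expected is the outer product of register_m and register_n laid out
        // row-major: results[i * 4 + j] = m[i] * n[j], e.g. 16 * 4 = 64,
        // 16 * 5 = 80, ..., 28 * 7 = 196.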
        let expected = &[
            64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0,
            140.0, 168.0, 196.0,
        ];
        assert_eq!(actual, expected);
    }

    /// Exported test
    pub fn tile_outer_product_scalar_unit_test<R: Runtime>(device: &R::Device) {
        let client = R::client(device);

@@ -22,24 +22,24 @@ pub(crate) fn tiling2d_core<F: Float>(
    config: Comptime<CubeTiling2dConfig>,
) {
    let block_size_k = Comptime::map(config, |c| c.block_size_k);
    let results = init_results(config);
    let results = init_results::<F>(config);

    let n_loops = calculate_n_loops::<F>(lhs.shape(lhs.rank() - UInt::new(1)), config);
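
    // Main k-loop: each iteration stages one block_size_k slice of lhs
    // (transposed) and rhs (plain) into shared memory, synchronizes, runs the
    // tile computation, then synchronizes again before the next slice.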
    for k in range(0u32, n_loops, Comptime::new(false)) {
        let k = k * Comptime::runtime(block_size_k);

        load_lhs_transposed(lhs, coordinates, k, offsets.lhs, shared.lhs, config);
        load_rhs_plain(rhs, coordinates, k, offsets.rhs, shared.rhs, config);
        load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config);
        load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config);

        sync_units();

        compute_loop(coordinates, shared.lhs, shared.rhs, results, config);
        compute_loop::<F>(coordinates, shared.lhs, shared.rhs, results, config);

        sync_units();
    }

    write_to_output(out, results, coordinates, offsets.out, config);
    write_to_output::<F>(out, results, coordinates, offsets.out, config);
}

#[cube]

@@ -20,7 +20,7 @@ pub(crate) fn write_to_output<F: Float>(
    let col = coordinates.skip_col + coordinates.unit_col;

    let rank = out.rank();
    let out_stride_row = out.stride(rank - UInt::new(2)) / Comptime::runtime(tile_size);
    let out_stride_row = out.stride(rank - UInt::new(2));

    if Comptime::get(check_m_bounds) {
        let dim_m = out.shape(rank - UInt::new(2));

@@ -28,7 +28,7 @@ pub(crate) fn write_to_output<F: Float>(
        let dim_n = out.shape(rank - UInt::new(1));
        if row < dim_m && col < dim_n {
            let num_writes = UInt::min(dim_m - row, Comptime::runtime(tile_size));
            write_results_to_output_partial(
            write_results_to_output_partial::<F>(
                out,
                results,
                row,

@@ -42,7 +42,7 @@ pub(crate) fn write_to_output<F: Float>(
        } else {
            if row < dim_m {
                let num_writes = UInt::min(dim_m - row, Comptime::runtime(tile_size));
                write_results_to_output_partial(
                write_results_to_output_partial::<F>(
                    out,
                    results,
                    row,

@@ -58,7 +58,7 @@ pub(crate) fn write_to_output<F: Float>(
    if Comptime::get(check_n_bounds) {
        let dim_n = out.shape(rank - UInt::new(1));
        if col < dim_n {
            write_results_to_output(
            write_results_to_output::<F>(
                out,
                results,
                row,

@@ -69,7 +69,7 @@ pub(crate) fn write_to_output<F: Float>(
            );
        }
    } else {
        write_results_to_output(
        write_results_to_output::<F>(
            out,
            results,
            row,

@@ -93,14 +93,15 @@ fn write_results_to_output<F: Float>(
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let is_scalar = Comptime::map(tile_size, |t| t.val == 1);
    let sm_is_scalar = Comptime::map(tile_size, |t| t.val == 1);
    let unroll = Comptime::map(config, |c| c.unroll);

    if Comptime::get(is_scalar) {
        out[row * out_stride_row + col + offset_output] = results[0];
    let vectorization_factor = Comptime::runtime(Comptime::vectorization(out));
    if Comptime::get(sm_is_scalar) {
        out[(row * out_stride_row + col + offset_output) / vectorization_factor] = results[0];
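        // Note: `out` appears to be indexed in vectorized units here, so the
        // element offset is divided by the vectorization factor instead of
        // baking the division into the stride as before.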
    } else {
        for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
            write_results_inner_loop(
            write_results_inner_loop::<F>(
                out,
                results,
                res_idx_m,

@@ -126,13 +127,14 @@ fn write_results_to_output_partial<F: Float>(
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let is_scalar = Comptime::map(tile_size, |t| t.val == 1);
    let sm_is_scalar = Comptime::map(tile_size, |t| t.val == 1);

    if Comptime::get(is_scalar) {
        out[row * out_stride_row + col + batch_offset / Comptime::runtime(tile_size)] = results[0];
    let vectorization_factor = Comptime::runtime(Comptime::vectorization(out));
    if Comptime::get(sm_is_scalar) {
        out[(row * out_stride_row + col + batch_offset) / vectorization_factor] = results[0];
    } else {
        for res_idx_m in range(0u32, num_writes, Comptime::new(false)) {
            write_results_inner_loop(
            write_results_inner_loop::<F>(
                out,
                results,
                res_idx_m,

@@ -161,16 +163,16 @@ fn write_results_inner_loop<F: Float>(
    let unroll = Comptime::map(config, |c| c.unroll);

    let results_pos_m = res_idx_m * Comptime::runtime(tile_size);
    let out_position =
        (row + res_idx_m) * out_stride_row + col / Comptime::runtime(tile_size) + batch_offset;
    let out_position = (row + res_idx_m) * out_stride_row + col + batch_offset;

    // Reinterpreting results as vectorized array
    let mut array = Array::<F>::new(Comptime::get(tile_size));
    let mut output = F::vectorized(0., Comptime::get(tile_size));
    for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
        array[res_idx_n] = results[results_pos_m + res_idx_n];
        output[res_idx_n] = results[results_pos_m + res_idx_n];
    }

    out[out_position] = array.to_vectorized(tile_size);
    let vectorization_factor = Comptime::runtime(Comptime::vectorization(out));
    out[out_position / vectorization_factor] = output;
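    // Design note: building the vectorized value with F::vectorized and writing
    // it in a single store replaces the earlier temporary Array plus
    // to_vectorized round trip.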
}

#[cfg(feature = "export_tests")]

@@ -184,9 +186,8 @@ pub mod tests {
        results: Array<F>,
        config: Comptime<CubeTiling2dConfig>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let out_stride_row = out.stride(out.rank() - UInt::new(2)) / Comptime::runtime(tile_size);
        write_results_inner_loop(
        let out_stride_row = out.stride(out.rank() - UInt::new(2));
        write_results_inner_loop::<F>(
            out,
            results,
            UInt::new(2),

@@ -204,9 +205,8 @@ pub mod tests {
        results: Array<F>,
        config: Comptime<CubeTiling2dConfig>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let out_stride_row = out.stride(out.rank() - UInt::new(2)) / Comptime::runtime(tile_size);
        write_results_to_output(
        let out_stride_row = out.stride(out.rank() - UInt::new(2));
        write_results_to_output::<F>(
            out,
            results,
            UInt::new(4),

@@ -223,9 +223,8 @@ pub mod tests {
        results: Array<F>,
        config: Comptime<CubeTiling2dConfig>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let out_stride_row = out.stride(out.rank() - UInt::new(2)) / Comptime::runtime(tile_size);
        write_results_to_output_partial(
        let out_stride_row = out.stride(out.rank() - UInt::new(2));
        write_results_to_output_partial::<F>(
            out,
            results,
            UInt::new(4),

@@ -249,7 +248,7 @@ pub mod tests {
            skip_row: UInt::new(0),
            skip_col: UInt::new(0),
        };
        write_to_output(out, results, coordinates, UInt::new(0), config);
        write_to_output::<F>(out, results, coordinates, UInt::new(0), config);
    }

    #[cube(launch)]

@@ -264,7 +263,7 @@ pub mod tests {
            skip_row: UInt::new(0),
            skip_col: UInt::new(0),
        };
        write_to_output(out, results, coordinates, UInt::new(0), config);
        write_to_output::<F>(out, results, coordinates, UInt::new(0), config);
    }

    /// Exported test

@@ -478,7 +478,7 @@ mod tests {

    #[test]
    pub fn medium() {
        test_with_params(16, 16, 16, 1, 1);
        test_with_params(17, 16, 16, 1, 1);
    }

    #[test]

@@ -15,6 +15,13 @@ mod tests {
        )
    }

    #[test]
    pub fn tiling2d_matmul_outer_product_vectorized_test_2() {
        outer_product_tests::tile_outer_product_vectorized_unit_test_2::<TestRuntime>(
            &Default::default(),
        )
    }

    #[test]
    pub fn tiling2d_matmul_outer_product_scalar_test() {
        outer_product_tests::tile_outer_product_scalar_unit_test::<TestRuntime>(&Default::default())

@@ -25,6 +32,11 @@ mod tests {
        compute_loop_tests::compute_loop_unit_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn compute_loop_unit_offset_test() {
        compute_loop_tests::compute_loop_unit_offset_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn tiling2d_matmul_read_whole_vectorized_test() {
        load_shared_memory_tests::read_whole_unit_test::<TestRuntime>(&Default::default())

@@ -74,11 +86,33 @@ mod tests {
        load_shared_memory_tests::load_lhs_transposed_unit_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn load_lhs_transposed_cube_test() {
        load_shared_memory_tests::load_lhs_transposed_cube_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn load_lhs_transposed_offset_cube_test() {
        load_shared_memory_tests::load_lhs_transposed_offset_cube_test::<TestRuntime>(
            &Default::default(),
        )
    }

    #[test]
    pub fn load_rhs_plain_unit_test() {
        load_shared_memory_tests::load_rhs_plain_unit_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn load_rhs_plain_cube_test() {
        load_shared_memory_tests::load_rhs_plain_cube_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn load_rhs_plain_cube_offset_test() {
        load_shared_memory_tests::load_rhs_plain_cube_offset_test::<TestRuntime>(&Default::default())
    }

    #[test]
    pub fn write_results_inner_loop_unit_test() {
        write_output_tests::write_results_inner_loop_unit_test::<TestRuntime>(&Default::default())

@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(matmul)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};
    use burn_tensor::{Data, Int, Tensor};

    #[test]
    fn test_matmul_d2() {

@@ -111,14 +111,102 @@ mod tests {
    #[test]
    fn test_matmul_trivial() {
        let device = Default::default();
        let tensor_1 = TestTensor::from_floats([[3., 4.]], &device);
        let tensor_2 = TestTensor::from_floats([[4.], [3.]], &device);

        // let tensor_3 = tensor_1.matmul(tensor_2);
        let tensor_3 = tensor_2.matmul(tensor_1);
        let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..16, &device)
            .reshape([4, 4])
            .float();

        // assert_eq!(tensor_3.into_data(), Data::from([[24.]]));
        assert_eq!(tensor_3.into_data(), Data::from([[12., 16.], [9., 12.]]));
        let tensor_3 = tensor_1.clone().matmul(tensor_1);
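        // Spot check, assuming a row-major reshape: row 0 dot column 0 is
        // 0*0 + 1*4 + 2*8 + 3*12 = 56, matching the top-left entry below.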

        assert_eq!(
            tensor_3.into_data(),
            Data::from([
                [56., 62., 68., 74.],
                [152., 174., 196., 218.],
                [248., 286., 324., 362.],
                [344., 398., 452., 506.]
            ])
        );
    }

    #[test]
    fn test_matmul_trivial_transposed() {
        let device = Default::default();

        let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..16, &device)
            .reshape([4, 4])
            .float();

        let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());
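        // A * A^T is symmetric, and entry (0, 0) is the squared norm of row 0:
        // 0^2 + 1^2 + 2^2 + 3^2 = 14.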

        assert_eq!(
            tensor_3.into_data(),
            Data::from([
                [14., 38., 62., 86.],
                [38., 126., 214., 302.],
                [62., 214., 366., 518.],
                [86., 302., 518., 734.]
            ])
        );
    }

    #[test]
    fn test_matmul_4_8() {
        let device = Default::default();

        let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..32, &device)
            .reshape([4, 8])
            .float();

        let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());
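        // Entry (0, 0) is the squared norm of row 0 of the 4x8 matrix:
        // 0 + 1 + 4 + 9 + 16 + 25 + 36 + 49 = 140.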

        assert_eq!(
            tensor_3.into_data(),
            Data::from([
                [140., 364., 588., 812.],
                [364., 1100., 1836., 2572.],
                [588., 1836., 3084., 4332.],
                [812., 2572., 4332., 6092.]
            ])
        );
    }

    #[test]
    fn test_matmul_8_4() {
        let device = Default::default();

        let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..32, &device)
            .reshape([8, 4])
            .float();
        let tensor_2 = Tensor::<TestBackend, 1, Int>::arange(0..32, &device)
            .reshape([4, 8])
            .float();

        let tensor_3 = tensor_1.clone().matmul(tensor_2);

        assert_eq!(
            tensor_3.into_data(),
            Data::from([
                [112., 118., 124., 130., 136., 142., 148., 154.],
                [304., 326., 348., 370., 392., 414., 436., 458.],
                [496., 534., 572., 610., 648., 686., 724., 762.],
                [688., 742., 796., 850., 904., 958., 1012., 1066.],
                [880., 950., 1020., 1090., 1160., 1230., 1300., 1370.],
                [1072., 1158., 1244., 1330., 1416., 1502., 1588., 1674.],
                [1264., 1366., 1468., 1570., 1672., 1774., 1876., 1978.],
                [1456., 1574., 1692., 1810., 1928., 2046., 2164., 2282.]
            ])
        );
        // [
        // [112.0, 118.0, 124.0, 130.0, 880.0, 950.0, 1020.0, 1090.0],
        // [304.0, 326.0, 348.0, 370.0, 1072.0, 1158.0, 1244.0, 1330.0],
        // [496.0, 534.0, 572.0, 610.0, 1264.0, 1366.0, 1468.0, 1570.0],
        // [688.0, 742.0, 796.0, 850.0, 1456.0, 1574.0, 1692.0, 1810.0],
        // [136.0, 142.0, 148.0, 154.0, 1160.0, 1230.0, 1300.0, 1370.0],
        // [392.0, 414.0, 436.0, 458.0, 1416.0, 1502.0, 1588.0, 1674.0],
        // [648.0, 686.0, 724.0, 762.0, 1672.0, 1774.0, 1876.0, 1978.0],
        // [904.0, 958.0, 1012.0, 1066.0, 1928.0, 2046.0, 2164.0, 2282.0],
        // ]
    }

    #[test]

@@ -151,7 +151,7 @@ impl Display for LearnerSummary {
        )?;

        if let Some(model) = &self.model {
            writeln!(f, "Model: {model}")?;
            writeln!(f, "Model:\n{model}")?;
        }
        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;

@@ -88,7 +88,7 @@ where
        }

        let compile = kernel.compile();
        println!("{}", compile.source);
        // println!("{}", compile.source);
        let pipeline = self.compile_source(&compile.source);

        self.pipelines.insert(kernel_id.clone(), pipeline.clone());