some debugging done

This commit is contained in:
louisfd 2024-06-26 12:17:27 -04:00
commit d9b4801448
52 changed files with 2108 additions and 340 deletions

125
Cargo.lock generated
View File

@ -171,7 +171,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -197,7 +197,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -463,6 +463,7 @@ version = "0.14.0"
dependencies = [
"async-trait",
"dashmap",
"data-encoding",
"derive-new",
"getrandom",
"indicatif",
@ -471,7 +472,6 @@ dependencies = [
"serde",
"spin",
"tokio",
"uuid",
"web-time",
]
@ -546,7 +546,7 @@ dependencies = [
"derive-new",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -604,7 +604,7 @@ dependencies = [
"derive-new",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -642,7 +642,7 @@ dependencies = [
"serde_json",
"strum",
"strum_macros",
"syn 2.0.66",
"syn 2.0.68",
"thiserror",
"tracing-core",
"tracing-subscriber",
@ -777,9 +777,9 @@ dependencies = [
[[package]]
name = "bytemuck"
version = "1.16.0"
version = "1.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5"
checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e"
dependencies = [
"bytemuck_derive",
]
@ -792,7 +792,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -843,7 +843,8 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.5.1"
source = "git+https://github.com/huggingface/candle.git?rev=82b641f#82b641fd2752e3b14db6a9c91faef70e3329f3b5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "311d8dbe293aa3b5c34f6a57727fafd67d17a74fa8b65276501237c233b34ffd"
dependencies = [
"accelerate-src",
"byteorder",
@ -869,7 +870,8 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.5.1"
source = "git+https://github.com/huggingface/candle.git?rev=82b641f#82b641fd2752e3b14db6a9c91faef70e3329f3b5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3b4b048ca298fb8be90b0f4d0fe68bdca9de956ab52bb6e381463d955f2b661"
dependencies = [
"bindgen_cuda",
]
@ -877,7 +879,8 @@ dependencies = [
[[package]]
name = "candle-metal-kernels"
version = "0.5.1"
source = "git+https://github.com/huggingface/candle.git?rev=82b641f#82b641fd2752e3b14db6a9c91faef70e3329f3b5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d31136c9541c82b7de0937c9a58210ada38e17d70810e0eacc0a99d849d848d"
dependencies = [
"metal 0.27.0",
"once_cell",
@ -1027,7 +1030,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1439,7 +1442,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1450,7 +1453,7 @@ checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
dependencies = [
"darling_core",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1466,6 +1469,12 @@ dependencies = [
"parking_lot_core 0.9.10",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
[[package]]
name = "deflate64"
version = "0.1.8"
@ -1489,7 +1498,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1500,7 +1509,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1521,7 +1530,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1531,7 +1540,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
dependencies = [
"derive_builder_core",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1542,7 +1551,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1617,7 +1626,7 @@ checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1669,7 +1678,7 @@ dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1855,7 +1864,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -1946,7 +1955,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -2879,7 +2888,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -3279,7 +3288,7 @@ checksum = "bf307cbbbd777a9c10cec88ddafee572b3484caad5cce0c9236523c3803105a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -3439,7 +3448,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -3701,7 +3710,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -3861,7 +3870,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -3961,9 +3970,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.85"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [
"unicode-ident",
]
@ -3984,7 +3993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -4535,7 +4544,7 @@ dependencies = [
"regex",
"relative-path",
"rustc_version",
"syn 2.0.66",
"syn 2.0.68",
"unicode-ident",
]
@ -4813,7 +4822,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -4880,7 +4889,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5030,7 +5039,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ff9eaf853dec4c8802325d8b6d3dffa86cc707fd7a1a4cdbf416e13b061787a"
dependencies = [
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5059,9 +5068,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.26.2"
version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros",
]
@ -5076,7 +5085,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5098,9 +5107,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.66"
version = "2.0.68"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9"
dependencies = [
"proc-macro2",
"quote",
@ -5121,7 +5130,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5312,7 +5321,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5450,7 +5459,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5586,7 +5595,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -5849,7 +5858,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
"wasm-bindgen-shared",
]
@ -5883,7 +5892,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@ -5957,9 +5966,9 @@ checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "wgpu"
version = "0.20.0"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32ff1bfee408e1028e2e3acbf6d32d98b08a5a059ccbf5f33305534453ba5d3e"
checksum = "90e37c7b9921b75dfd26dd973fdcbce36f13dfa6e2dc82aece584e0ed48c355c"
dependencies = [
"arrayvec",
"cfg-if",
@ -5983,9 +5992,9 @@ dependencies = [
[[package]]
name = "wgpu-core"
version = "0.20.0"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6a86eaa5e763e59c73cf9e97d55fffd4dfda69fd8bda19589fcf851ddfef1f"
checksum = "d59e0d5fc509601c69e4e1fa06c1eb3c4c9f12956a5e30c79b61ef1c1be7daf0"
dependencies = [
"arrayvec",
"bit-vec",
@ -6010,9 +6019,9 @@ dependencies = [
[[package]]
name = "wgpu-hal"
version = "0.20.0"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d71c8ae05170583049b65ee562fd839fdc0b3e9ddb84f4e40c9d5f8ea0d4c8c"
checksum = "6aa24c3889f885a3fb9133b454c8418bfcfaadcfe4ed3be96ac80e76703b863b"
dependencies = [
"android_system_properties",
"arrayvec",
@ -6309,7 +6318,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -6392,7 +6401,7 @@ checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
"synstructure",
]
@ -6413,7 +6422,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]
@ -6433,7 +6442,7 @@ checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
"synstructure",
]
@ -6454,7 +6463,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
"syn 2.0.68",
]
[[package]]

View File

@ -27,14 +27,16 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
async-trait = "0.1.80"
bytemuck = "1.16.0"
# candle-core = { version = "0.4.1" }
candle-core = { git = "https://github.com/huggingface/candle.git", rev = "82b641f" }
bytemuck = "1.16.1"
candle-core = { version = "0.5.1" }
clap = { version = "4.5.7", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
csv = "1.3.0"
dashmap = "5.5.3"
data-encoding = { version = "2.6.0", default-features = false, features = [
"alloc",
] }
dirs = "5.0.1"
fake = "2.9.2"
flate2 = "1.0.30"
@ -43,16 +45,19 @@ getrandom = { version = "0.2.15", default-features = false }
gix-tempfile = { version = "13.1.1", features = ["signals"] }
globwalk = "0.9.1"
hashbrown = "0.14.5"
hound = "3.5.1"
image = "0.25.1"
indicatif = "0.17.8"
js-sys = "0.3.69"
libm = "0.2.8"
log = { default-features = false, version = "0.4.21" }
md5 = "0.7.0"
percent-encoding = "2.3.1"
pretty_assertions = "1.4.0"
proc-macro2 = "1.0.85"
proc-macro2 = "1.0.86"
protobuf = "3.4.0"
protobuf-codegen = "3.4.0"
quote = "1.0.36"
percent-encoding = "2.3.1"
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.24.0" }
rayon = "1.10.0"
@ -64,21 +69,18 @@ rusqlite = { version = "0.31.0" }
rust-format = { version = "0.3.4" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.35.0"
serial_test = "3.1.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
strum = "0.26.2"
strum = "0.26.3"
strum_macros = "0.26.4"
syn = { version = "2.0.66", features = ["full", "extra-traits"] }
syn = { version = "2.0.68", features = ["full", "extra-traits"] }
tempfile = "3.10.1"
thiserror = "1.0.61"
tokio = { version = "1.38.0", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.32"
tracing-subscriber = "0.3.18"
md5 = "0.7.0"
serial_test = "3.1.1"
web-time = "1.1.0"
hound = "3.5.1"
image = "0.25.1"
zip = "2.1.3"
# Terminal UI
@ -89,7 +91,7 @@ crossterm = "0.27.0"
futures-intrusive = "0.5.0"
text_placeholder = "0.5.0"
pollster = "0.3.0"
wgpu = "0.20.0"
wgpu = "0.20.1"
# Benchmarks and Burnbench
arboard = "3.4.0"

View File

@ -1,12 +1,12 @@
# Learner
The [burn-train](https://github.com/tracel-ai/burn/tree/main/burn-train) crate encapsulates multiple
utilities for training deep learning models. The goal of the crate is to provide users with a
well-crafted and flexible training loop, so that projects do not have to write such components from
the ground up. Most of the interactions with `burn-train` will be with the `LearnerBuilder` struct,
briefly presented in the previous [training section](../basic-workflow/training.md). This struct
enables you to configure the training loop, offering support for registering metrics, enabling
logging, checkpointing states, using multiple devices, and so on.
The [burn-train](https://github.com/tracel-ai/burn/tree/main/crates/burn-train) crate encapsulates
multiple utilities for training deep learning models. The goal of the crate is to provide users with
a well-crafted and flexible training loop, so that projects do not have to write such components
from the ground up. Most of the interactions with `burn-train` will be with the `LearnerBuilder`
struct, briefly presented in the previous [training section](../basic-workflow/training.md). This
struct enables you to configure the training loop, offering support for registering metrics,
enabling logging, checkpointing states, using multiple devices, and so on.
There are still some assumptions in the current provided APIs, which may make them inappropriate for
your learning requirements. Indeed, they assume your model will learn from a training dataset and be

View File

@ -12,7 +12,7 @@ version.workspace = true
[features]
default = ["std"]
std = ["rand/std"]
std = ["rand/std", "data-encoding/std"]
doc = ["default"]
wasm-sync = []
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
@ -27,10 +27,10 @@ web-time = { version = "1.1.0" }
# ** Please make sure all dependencies support no_std when std is disabled **
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
uuid = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
derive-new = { workspace = true }
serde = { workspace = true }
data-encoding = { workspace = true }
# Network downloader
indicatif = { workspace = true, optional = true }

View File

@ -1,18 +1,21 @@
use alloc::string::String;
use crate::rand::gen_random;
use alloc::string::{String, ToString};
use uuid::{Builder, Bytes};
use data_encoding::BASE32_DNSSEC;
/// Simple ID generator.
pub struct IdGenerator {}
impl IdGenerator {
/// Generates a new ID in the form of a UUID.
/// Generates a new ID.
pub fn generate() -> String {
let random_bytes: Bytes = gen_random();
// Generate 6 random bytes (281,474,976,710,656 combinations)
let random_bytes: [u8; 6] = gen_random();
let uuid = Builder::from_random_bytes(random_bytes).into_uuid();
uuid.as_hyphenated().to_string()
// Encode the random bytes in base32 DNSSEC
// 6 bytes encodes to 10 lower case characters, e.g. "3uu5e6vv7c"
BASE32_DNSSEC.encode(&random_bytes)
}
}

View File

@ -0,0 +1,547 @@
use alloc::{
borrow::ToOwned,
format,
string::{String, ToString},
vec::Vec,
};
use core::any;
use core::fmt::{Display, Write};
/// Default display settings for a module.
pub trait ModuleDisplayDefault {
/// Attributes of the module used for display purposes.
///
/// # Arguments
///
/// * `_content` - The content object that contains display settings and attributes.
///
/// # Returns
///
/// An optional content object containing the display attributes.
fn content(&self, _content: Content) -> Option<Content>;
/// Gets the number of the parameters of the module.
fn num_params(&self) -> usize {
0
}
}
/// Trait to implement custom display settings for a module.
///
/// In order to implement custom display settings for a module,
/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
/// 2. Implement ModuleDisplay trait for the module
pub trait ModuleDisplay: ModuleDisplayDefault {
/// Formats the module with provided display settings.
///
/// # Arguments
///
/// * `passed_settings` - Display settings passed to the module.
///
/// # Returns
///
/// A string representation of the formatted module.
fn format(&self, passed_settings: DisplaySettings) -> String {
let settings = if let Some(custom_settings) = self.custom_settings() {
custom_settings.inherit(passed_settings)
} else {
passed_settings
};
let indent = " ".repeat(settings.level * settings.indentation_size());
let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
let settings = settings.level_up();
let self_type = extract_type_name::<Self>();
// Use custom content if it is implemented and show_all_attributes is false,
// otherwise use default content
let content = if !settings.show_all_attributes() {
self.custom_content(Content::new(settings.clone()))
.unwrap_or_else(|| {
self.content(Content::new(settings.clone()))
.unwrap_or_else(|| {
panic!("Default content should be implemented for {self_type}.")
})
})
} else {
self.content(Content::new(settings.clone()))
.unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
};
let top_level_type = if let Some(top_level_type) = content.top_level_type {
top_level_type.to_owned()
} else {
self_type.to_owned()
};
// If there is only one item in the content, return it or no attributes
if let Some(item) = content.single_item {
return item;
} else if content.attributes.is_empty() {
return top_level_type.to_string();
}
let mut result = String::new();
// Print the struct name
if settings.new_line_after_attribute() {
writeln!(result, "{} {{", top_level_type).unwrap();
} else {
write!(result, "{} {{", top_level_type).unwrap();
}
for (i, attribute) in content.attributes.iter().enumerate() {
if settings.new_line_after_attribute() {
writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
} else if i == 0 {
write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
} else {
write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
}
}
if settings.show_num_parameters() {
let num_params = self.num_params();
if num_params > 0 {
if settings.new_line_after_attribute() {
writeln!(result, "{indent}params: {}", num_params).unwrap();
} else {
write!(result, ", params: {}", num_params).unwrap();
}
}
}
if settings.new_line_after_attribute() {
write!(result, "{indent_close_braces}}}").unwrap();
} else {
write!(result, "}}").unwrap();
}
result
}
/// Custom display settings for the module.
///
/// # Returns
///
/// An optional display settings object.
fn custom_settings(&self) -> Option<DisplaySettings> {
None
}
/// Custom attributes for the module.
///
/// # Arguments
///
/// * `_content` - The content object that contains display settings and attributes.
///
/// # Returns
///
/// An optional content object containing the custom attributes.
fn custom_content(&self, _content: Content) -> Option<Content> {
None
}
}
/// Custom module display settings.
#[derive(Debug, Clone)]
pub struct DisplaySettings {
/// Whether to print the module parameter ids.
show_param_id: Option<bool>,
/// Whether to print the module attributes.
show_all_attributes: Option<bool>,
/// Whether to print the module number of parameters.
show_num_parameters: Option<bool>,
/// Print new line after an attribute.
new_line_after_attribute: Option<bool>,
/// Indentation size.
indentation_size: Option<usize>,
/// Level of indentation.
level: usize,
}
impl Default for DisplaySettings {
fn default() -> Self {
DisplaySettings {
show_param_id: None,
show_all_attributes: None,
show_num_parameters: None,
new_line_after_attribute: None,
indentation_size: None,
level: 1,
}
}
}
impl DisplaySettings {
/// Create a new format settings.
///
/// # Returns
///
/// A new instance of `DisplaySettings`.
pub fn new() -> Self {
Default::default()
}
/// Sets a flag to show module parameters.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show module parameters.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_param_id(mut self, flag: bool) -> Self {
self.show_param_id = Some(flag);
self
}
/// Sets a flag to show module attributes.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show all module attributes.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
self.show_all_attributes = Some(flag);
self
}
/// Sets a flag to show the number of module parameters.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show the number of module parameters.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
self.show_num_parameters = Some(flag);
self
}
/// Sets a flag to print a new line after an attribute.
///
/// # Arguments
///
/// * `flag` - Boolean flag to print a new line after an attribute.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
self.new_line_after_attribute = Some(flag);
self
}
/// Sets the indentation size.
///
/// # Arguments
///
/// * `size` - The size of the indentation.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_indentation_size(mut self, size: usize) -> Self {
self.indentation_size = Some(size);
self
}
/// Inherits settings from the provided settings and return a new settings object.
///
/// # Arguments
///
/// * `top` - The top level `DisplaySettings` to inherit from.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn inherit(self, top: Self) -> Self {
let mut updated = self.clone();
if let Some(show_param_id) = top.show_param_id {
updated.show_param_id = Some(show_param_id);
};
if let Some(show_all_attributes) = top.show_all_attributes {
updated.show_all_attributes = Some(show_all_attributes);
}
if let Some(show_num_parameters) = top.show_num_parameters {
updated.show_num_parameters = Some(show_num_parameters);
}
if let Some(new_line_after_attribute) = top.new_line_after_attribute {
updated.new_line_after_attribute = Some(new_line_after_attribute);
}
if let Some(indentation_size) = top.indentation_size {
updated.indentation_size = Some(indentation_size);
}
updated.level = top.level;
updated
}
/// A convenience method to wrap the DisplaySettings struct in an option.
///
/// # Returns
///
/// An optional `DisplaySettings`.
pub fn optional(self) -> Option<Self> {
Some(self)
}
/// Increases the level of indentation.
///
/// # Returns
///
/// Updated `DisplaySettings` instance with increased indentation level.
pub fn level_up(mut self) -> Self {
self.level += 1;
self
}
/// Gets `show_param_id` flag, substitutes false if not set.
///
/// This flag is used to print the module parameter ids.
///
/// # Returns
///
/// A boolean value indicating whether to show parameter ids.
pub fn show_param_id(&self) -> bool {
self.show_param_id.unwrap_or(false)
}
/// Gets `show_all_attributes`, substitutes false if not set.
///
/// This flag is used to force to print all module attributes, overriding custom attributes.
///
/// # Returns
///
/// A boolean value indicating whether to show all attributes.
pub fn show_all_attributes(&self) -> bool {
self.show_all_attributes.unwrap_or(false)
}
/// Gets `show_num_parameters`, substitutes true if not set.
///
/// This flag is used to print the number of module parameters.
///
/// # Returns
///
/// A boolean value indicating whether to show the number of parameters.
pub fn show_num_parameters(&self) -> bool {
self.show_num_parameters.unwrap_or(true)
}
/// Gets `new_line_after_attribute`, substitutes true if not set.
///
/// This flag is used to print a new line after an attribute.
///
/// # Returns
///
/// A boolean value indicating whether to print a new line after an attribute.
pub fn new_line_after_attribute(&self) -> bool {
self.new_line_after_attribute.unwrap_or(true)
}
/// Gets `indentation_size`, substitutes 2 if not set.
///
/// This flag is used to set the size of indentation.
///
/// # Returns
///
/// An integer value indicating the size of indentation.
pub fn indentation_size(&self) -> usize {
self.indentation_size.unwrap_or(2)
}
}
/// Struct to store the attributes of a module for formatting.
#[derive(Clone, Debug)]
pub struct Content {
/// List of attributes.
pub attributes: Vec<Attribute>,
/// Single item content.
pub single_item: Option<String>,
/// Display settings.
pub display_settings: DisplaySettings,
/// Top level type name.
pub top_level_type: Option<String>,
}
impl Content {
/// Creates a new attributes struct.
///
/// # Arguments
///
/// * `display_settings` - Display settings for the content.
///
/// # Returns
///
/// A new instance of `Content`.
pub fn new(display_settings: DisplaySettings) -> Self {
Content {
attributes: Vec::new(),
single_item: None,
display_settings,
top_level_type: None,
}
}
/// Adds an attribute to the format settings. The value will be formatted and stored as a string.
///
/// # Arguments
///
/// * `name` - Name of the attribute.
/// * `value` - Value of the attribute.
///
/// # Returns
///
/// Updated `Content` instance with the new attribute added.
pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
if self.single_item.is_some() {
panic!("Cannot add multiple attributes when single item is set.");
}
let attribute = Attribute {
name: name.to_owned(),
value: value.format(self.display_settings.clone()), // TODO level + 1
ty: any::type_name::<T>().to_string(),
};
self.attributes.push(attribute);
self
}
/// Adds a single item.
///
/// # Arguments
///
/// * `value` - Rendered string of the single item.
///
/// # Returns
///
/// Updated `Content` instance with the single item added.
pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
if !self.attributes.is_empty() {
panic!("Cannot add single item when attributes are set.");
}
self.single_item = Some(value.format(self.display_settings.clone()));
self
}
/// Adds a single item.
///
/// # Arguments
///
/// * `value` - Formatted display value.
///
/// # Returns
///
/// Updated `Content` instance with the formatted single item added.
pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
if !self.attributes.is_empty() {
panic!("Cannot add single item when attributes are set.");
}
self.single_item = Some(format!("{}", value));
self
}
/// A convenience method to wrap the Attributes struct in an option
/// because it is often used as an optional field.
///
/// # Returns
///
/// An optional `Content`.
pub fn optional(self) -> Option<Self> {
if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
{
None
} else {
Some(self)
}
}
/// Sets the top level type name.
///
/// # Arguments
///
/// * `ty` - The type name to set.
///
/// # Returns
///
/// Updated `Content` instance with the top level type name set.
pub fn set_top_level_type(mut self, ty: &str) -> Self {
self.top_level_type = Some(ty.to_owned());
self
}
}
/// Attribute to print in the display method.
#[derive(Clone, Debug)]
pub struct Attribute {
/// Name of the attribute.
pub name: String,
/// Value of the attribute.
pub value: String,
/// Type of the attribute.
pub ty: String,
}
/// Extracts the short name of a type T
///
/// # Returns
///
/// A string slice representing the short name of the type.
pub fn extract_type_name<T: ?Sized>() -> &'static str {
// Get the full type name of T, including module path and generic parameters
let ty = any::type_name::<T>();
// Find the first occurrence of '<' in the full type name
// If not found, use the length of the type name
let end = ty.find('<').unwrap_or(ty.len());
// Slice the type name up to the first '<' or the end
let ty = &ty[0..end];
// Find the last occurrence of "::" in the sliced type name
// If found, add 2 to skip the "::" itself
// If not found, start from the beginning of the type name
let start = ty.rfind("::").map(|i| i + 2).unwrap_or(0);
// Find the last occurrence of '<' in the sliced type name
// If not found, use the length of the type name
let end = ty.rfind('<').unwrap_or(ty.len());
// If the start index is less than the end index,
// return the slice of the type name from start to end
// Otherwise, return the entire sliced type name
if start < end {
&ty[start..end]
} else {
ty
}
}

View File

@ -1,5 +1,7 @@
mod base;
mod display;
mod param;
pub use base::*;
pub use display::*;
pub use param::*;

View File

@ -1,6 +1,12 @@
use alloc::{format, string::ToString};
use core::{fmt::Display, marker::PhantomData};
use crate::{
self as burn,
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
module::{
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
ModuleMapper, ModuleVisitor,
},
record::Record,
};
use burn::record::PrecisionSettings;
@ -8,7 +14,6 @@ use burn_tensor::{
backend::{AutodiffBackend, Backend},
BasicAutodiffOps, BasicOps, Tensor,
};
use core::marker::PhantomData;
/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new, Default)]
@ -96,6 +101,15 @@ macro_rules! constant {
impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
constant!(ad_module, $type);
}
impl burn::module::ModuleDisplayDefault for $type {
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
let string = format!("{}", self);
content.add_formatted(&string).optional()
}
}
impl burn::module::ModuleDisplay for $type {}
};
}
@ -122,6 +136,13 @@ constant!(i32);
constant!(i16);
constant!(i8);
impl burn::module::ModuleDisplay for str {}
impl burn::module::ModuleDisplayDefault for str {
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
content.add_formatted(&self).optional()
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
type Record = ConstantRecord;
@ -158,6 +179,15 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
fn content(&self, content: Content) -> Option<Content> {
let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
content.add_single(&string).optional()
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
for Tensor<B, D, K>
{
@ -200,6 +230,14 @@ impl<B: Backend> Module<B> for PhantomData<B> {
}
}
impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
fn content(&self, content: Content) -> Option<Content> {
content.add_single(&"PhantomData".to_string()).optional()
}
}
impl<B: Backend> ModuleDisplay for PhantomData<B> {}
impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
type InnerModule = PhantomData<B::InnerBackend>;
@ -248,6 +286,27 @@ where
}
}
impl<T> ModuleDisplayDefault for Ignored<T>
where
T: Sync + Send + core::fmt::Debug + Clone,
{
fn content(&self, content: Content) -> Option<Content> {
// For now, just print the debug representation of the ignored value
content.add_single(&format!("{:?}", self.0)).optional()
}
}
impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
impl<T> Display for Ignored<T>
where
T: Sync + Send + core::fmt::Debug + Clone,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
where
B: AutodiffBackend,

View File

@ -1,5 +1,10 @@
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
use alloc::vec::Vec;
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use alloc::{format, vec::Vec};
use burn_tensor::backend::{AutodiffBackend, Backend};
use core::fmt::Debug;
@ -52,6 +57,17 @@ where
}
}
impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
fn content(&self, content: Content) -> Option<Content> {
match self {
Some(module) => content.add_single(module).optional(),
None => content.add_single("None").optional(),
}
}
}
impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
impl<T, B> AutodiffModule<B> for Option<T>
where
T: AutodiffModule<B> + Debug + Send + Clone,
@ -128,6 +144,21 @@ where
}
}
impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
fn content(&self, content: Content) -> Option<Content> {
self.iter()
.enumerate()
.fold(content, |acc, (i, module)| {
let index = format!("{}", i);
acc.add(&index, module)
})
.set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
.optional()
}
}
impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
impl<T, B> AutodiffModule<B> for Vec<T>
where
T: AutodiffModule<B> + Debug + Send + Clone,
@ -197,6 +228,21 @@ where
}
}
impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
fn content(&self, content: Content) -> Option<Content> {
self.iter()
.enumerate()
.fold(content, |acc, (i, module)| {
let index = format!("{}", i);
acc.add(&index, module)
})
.set_top_level_type(format!("[0..{}]", self.len()).as_str())
.optional()
}
}
impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
where
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
@ -269,6 +315,21 @@ macro_rules! impl_module_tuple {
($(self.$i.valid(),)*)
}
}
impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
where
$($l: ModuleDisplay,)*
{
fn content(&self, content: Content) -> Option<Content> {
let content = content
$(.add(&format!("{}", $i), &self.$i))*
.set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
content.optional()
}
}
impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
};
}

View File

@ -1,7 +1,13 @@
use super::ParamId;
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor, Param,
};
use alloc::string::ToString;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::stub::Mutex;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
@ -45,6 +51,24 @@ pub struct RunningState<V> {
value: Arc<Mutex<V>>,
}
// Implement display for the module
impl<V> core::fmt::Display for RunningState<V> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "RunningState(id={})", self.id)
}
}
impl<V> ModuleDisplayDefault for RunningState<V> {
fn content(&self, content: Content) -> Option<Content> {
content
.add_formatted(&"RunningState".to_string())
.optional()
}
}
impl<V> ModuleDisplay for RunningState<V> {}
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
type Record = Param<Tensor<B, D>>;

View File

@ -1,10 +1,13 @@
use super::{Param, ParamId, Parameter};
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use crate::tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
};
use alloc::vec::Vec;
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Data, Float, Int};
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
@ -147,6 +150,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
}
}
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
fn content(&self, content: Content) -> Option<Content> {
let id = if content.display_settings.show_param_id() {
format!(", id: {}", self.id)
} else {
"".to_string()
};
let string = format!(
"ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
self.shape().dims
);
content.add_formatted(&string).optional()
}
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
type Record = Param<Tensor<B, D, Int>>;
@ -198,6 +217,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
}
}
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
fn content(&self, content: Content) -> Option<Content> {
let id = if content.display_settings.show_param_id() {
format!(", id: {}", self.id)
} else {
"".to_string()
};
let string = format!(
"ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
self.shape().dims
);
content.add_formatted(&string).optional()
}
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
type Record = Param<Tensor<B, D, Bool>>;
@ -249,6 +284,24 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
}
}
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
fn content(&self, content: Content) -> Option<Content> {
let id = if content.display_settings.show_param_id() {
format!(", id: {}", self.id)
} else {
"".to_string()
};
let string = format!(
"ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
self.shape().dims
);
content.add_formatted(&string).optional()
}
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
type InnerModule = Param<Tensor<B::InnerBackend, D>>;

View File

@ -1,14 +1,13 @@
use alloc::format;
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::nn::conv::checks;
use crate::nn::{Initializer, PaddingConfig1d};
use crate::tensor::backend::Backend;
use crate::tensor::module::conv1d;
use crate::tensor::ops::ConvOptions;
use crate::tensor::Tensor;
use crate::{
config::Config,
module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param},
nn::{conv::checks, Initializer, PaddingConfig1d},
tensor::{backend::Backend, module::conv1d, ops::ConvOptions, Tensor},
};
/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
#[derive(Config, Debug)]
@ -45,6 +44,7 @@ pub struct Conv1dConfig {
///
/// Should be created with [Conv1dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
pub weight: Param<Tensor<B, 3>>,
@ -54,7 +54,28 @@ pub struct Conv1d<B: Backend> {
kernel_size: usize,
dilation: usize,
groups: usize,
padding: PaddingConfig1d,
padding: Ignored<PaddingConfig1d>,
}
impl<B: Backend> ModuleDisplay for Conv1d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
content
.add("stride", &self.stride)
.add("kernel_size", &self.kernel_size)
.add("dilation", &self.dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl Conv1dConfig {
@ -87,7 +108,7 @@ impl Conv1dConfig {
bias,
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
groups: self.groups,
}

View File

@ -1,8 +1,9 @@
use alloc::format;
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
use crate::nn::Initializer;
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
@ -45,6 +46,7 @@ pub struct Conv2dConfig {
///
/// Should be created with [Conv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv2d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
@ -54,7 +56,7 @@ pub struct Conv2d<B: Backend> {
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: PaddingConfig2d,
padding: Ignored<PaddingConfig2d>,
}
impl Conv2dConfig {
@ -93,12 +95,38 @@ impl Conv2dConfig {
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
groups: self.groups,
}
}
}
impl<B: Backend> ModuleDisplay for Conv2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
content
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl<B: Backend> Conv2d<B> {
/// Applies the forward pass on the input tensor.
///

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
@ -21,6 +21,7 @@ pub struct DropoutConfig {
///
/// Should be created with [DropoutConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Dropout {
prob: f64,
}
@ -54,6 +55,18 @@ impl Dropout {
}
}
impl ModuleDisplay for Dropout {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
content.add("prob", &self.prob).optional()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,4 +1,6 @@
use crate as burn;
use crate::module::DisplaySettings;
use crate::module::ModuleDisplay;
use crate::config::Config;
use crate::module::Module;
@ -30,6 +32,7 @@ pub struct LinearConfig {
///
/// `O = IW + b`
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
/// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
@ -83,6 +86,23 @@ impl<B: Backend> Linear<B> {
}
}
impl<B: Backend> ModuleDisplay for Linear<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
let [d_input, d_output] = self.weight.shape().dims;
content
.add("d_input", &d_input)
.add("d_output", &d_output)
.add("bias", &self.bias.is_some())
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,4 +1,5 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::{
@ -33,6 +34,7 @@ pub struct BatchNormConfig {
///
/// Should be created using [BatchNormConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend, const D: usize> {
/// The learnable weight gamma.
pub gamma: Param<Tensor<B, 1>>,
@ -183,6 +185,24 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
}
}
impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [num_features] = self.beta.shape().dims;
content
.add("num_features", &num_features)
.add("momentum", &self.momentum)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {

View File

@ -1,7 +1,8 @@
use crate as burn;
use crate::config::Config;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
@ -29,6 +30,7 @@ pub struct LayerNormConfig {
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
/// The learnable weight.
gamma: Param<Tensor<B, 1>>,
@ -71,6 +73,22 @@ impl<B: Backend> LayerNorm<B> {
}
}
impl<B: Backend> ModuleDisplay for LayerNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
let [d_model] = self.gamma.shape().dims;
content
.add("d_model", &d_model)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -3,10 +3,9 @@ use crate as burn;
use crate::tensor::ops::conv::calculate_conv_padding;
use crate::config::Config;
use crate::module::Module;
/// Padding configuration for 1D operators.
#[derive(Module, Config, Debug, PartialEq)]
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
/// the same as the input.
@ -34,7 +33,7 @@ impl PaddingConfig1d {
}
/// Padding configuration for 2D operators.
#[derive(Module, Config, Debug, PartialEq)]
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
/// the same as the input.

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -43,7 +43,7 @@ pub struct AvgPool1dConfig {
pub struct AvgPool1d {
stride: usize,
kernel_size: usize,
padding: PaddingConfig1d,
padding: Ignored<PaddingConfig1d>,
count_include_pad: bool,
}
@ -53,7 +53,7 @@ impl AvgPool1dConfig {
AvgPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
count_include_pad: self.count_include_pad,
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -42,7 +42,7 @@ pub struct AvgPool2dConfig {
pub struct AvgPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: PaddingConfig2d,
padding: Ignored<PaddingConfig2d>,
count_include_pad: bool,
}
@ -52,7 +52,7 @@ impl AvgPool2dConfig {
AvgPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
count_include_pad: self.count_include_pad,
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -31,7 +31,7 @@ pub struct MaxPool1dConfig {
pub struct MaxPool1d {
stride: usize,
kernel_size: usize,
padding: PaddingConfig1d,
padding: Ignored<PaddingConfig1d>,
dilation: usize,
}
@ -41,7 +41,7 @@ impl MaxPool1dConfig {
MaxPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -31,7 +31,7 @@ pub struct MaxPool2dConfig {
pub struct MaxPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: PaddingConfig2d,
padding: Ignored<PaddingConfig2d>,
dilation: [usize; 2],
}
@ -41,7 +41,7 @@ impl MaxPool2dConfig {
MaxPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
}
}

View File

@ -1,12 +1,11 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::ops::UnfoldOptions;
use crate::tensor::Tensor;
use crate::tensor::module::unfold4d;
use crate::module::{Ignored, Module};
use burn_tensor::backend::Backend;
use burn_tensor::module::unfold4d;
use burn_tensor::ops::UnfoldOptions;
use burn_tensor::Tensor;
/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
#[derive(Config, Debug)]
@ -29,14 +28,14 @@ pub struct Unfold4dConfig {
/// Should be created with [Unfold4dConfig].
#[derive(Module, Clone, Debug)]
pub struct Unfold4d {
config: Unfold4dConfig,
config: Ignored<Unfold4dConfig>,
}
impl Unfold4dConfig {
/// Initializes a new [Unfold4d] module.
pub fn init(&self) -> Unfold4d {
Unfold4d {
config: self.clone(),
config: Ignored(self.clone()),
}
}
}
@ -48,7 +47,7 @@ impl Unfold4d {
///
/// # Shapes
///
/// input: `[batch_size, channels_in, height, width]`
/// input: `[batch_size, channels_in, height, width]`
/// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
unfold4d(

View File

@ -370,6 +370,6 @@ mod tests {
// Compare the lengths of expected and actual serialized strings because
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(serialized_str.len(), 134);
assert_eq!(serialized_str.len(), 108);
}
}

View File

@ -49,6 +49,12 @@ pub(crate) fn codegen_block(
pub(crate) struct Codegen {
pub tokens: proc_macro2::TokenStream,
pub is_comptime: bool,
pub array_indexing: Option<ArrayIndexing>,
}
pub(crate) struct ArrayIndexing {
pub array: proc_macro2::TokenStream,
pub index: proc_macro2::TokenStream,
}
impl From<proc_macro2::TokenStream> for Codegen {
@ -56,6 +62,7 @@ impl From<proc_macro2::TokenStream> for Codegen {
Self {
tokens,
is_comptime: false,
array_indexing: None,
}
}
}
@ -65,6 +72,7 @@ impl Codegen {
Self {
tokens: tokens.into(),
is_comptime,
array_indexing: None,
}
}

View File

@ -25,6 +25,7 @@ pub(crate) fn codegen_expr(
syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker),
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker),
_ => {
let mut array_indexing = None;
let tokens = match expr {
syn::Expr::Path(path) => {
return codegen_path_var(path, loop_level, variable_tracker)
@ -50,7 +51,11 @@ pub(crate) fn codegen_expr(
syn::Expr::MethodCall(call) => {
codegen_expr_method_call(call, loop_level, variable_tracker)
}
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_tracker),
syn::Expr::Index(index) => {
let codegen = codegen_index(index, loop_level, variable_tracker);
array_indexing = codegen.array_indexing;
codegen.tokens
}
syn::Expr::Array(array) => codegen_array_lit(array),
syn::Expr::Reference(reference) => {
codegen_ref(reference, loop_level, variable_tracker)
@ -67,7 +72,9 @@ pub(crate) fn codegen_expr(
}
};
Codegen::new(tokens, false)
let mut codegen = Codegen::new(tokens, false);
codegen.array_indexing = array_indexing;
codegen
}
}
}

View File

@ -8,7 +8,8 @@ pub(crate) fn codegen_binary(
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> Codegen {
let (lhs, is_comptime_lhs) = codegen_expr(&binary.left, loop_level, variable_tracker).split();
let lhs = codegen_expr(&binary.left, loop_level, variable_tracker);
let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing);
let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split();
if is_comptime_lhs && is_comptime_rhs {
@ -99,34 +100,94 @@ pub(crate) fn codegen_binary(
burn_cube::frontend::eq::expand(context, _lhs, _rhs)
}
},
syn::BinOp::AddAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
syn::BinOp::AddAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);
quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
syn::BinOp::SubAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
}
syn::BinOp::SubAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);
quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
syn::BinOp::MulAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
}
syn::BinOp::MulAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);
quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
syn::BinOp::DivAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
}
syn::BinOp::DivAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);
quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
}
syn::BinOp::And(_) => quote::quote! {
{

View File

@ -99,17 +99,25 @@ pub(crate) fn codegen_index(
index: &syn::ExprIndex,
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> TokenStream {
) -> Codegen {
let array = codegen_expr(&index.expr, loop_level, variable_tracker);
let index = codegen_expr(&index.index, loop_level, variable_tracker);
quote::quote! {
let tokens = quote::quote! {
{
let _array = #array;
let _index = #index;
burn_cube::frontend::index::expand(context, _array, _index)
}
}
};
let mut codegen = Codegen::new(tokens, false);
codegen.array_indexing = Some(super::base::ArrayIndexing {
array: array.tokens,
index: index.tokens,
});
codegen
}
/// Codegen for assignation

View File

@ -1,5 +1,5 @@
use crate::ir::{Elem, Item, Visibility};
use crate::prelude::{ArrayExpand, CubeElem, KernelDefinition};
use crate::prelude::KernelDefinition;
use crate::KernelSettings;
use crate::{
frontend::{CubeContext, ExpandElement},
@ -60,21 +60,21 @@ impl KernelBuilder {
}
/// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
pub fn output_array<T: CubeElem>(&mut self, item: Item) -> ArrayExpand<T> {
pub fn output_array(&mut self, item: Item) -> ExpandElement {
self.outputs.push(OutputInfo::Array { item });
let variable = self.context.output_array(self.num_output, item);
let variable = self.context.output(self.num_output, item);
self.num_output += 1;
variable
}
/// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
pub fn input_array<T: CubeElem>(&mut self, item: Item) -> ArrayExpand<T> {
pub fn input_array(&mut self, item: Item) -> ExpandElement {
self.inputs.push(InputInfo::Array {
item,
visibility: Visibility::Read,
});
let variable = self.context.input_array(self.num_input, item);
let variable = self.context.input(self.num_input, item);
self.num_input += 1;
variable
}

View File

@ -4,7 +4,7 @@ use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;
use super::{ArrayExpand, CubeElem, SharedMemoryExpand};
use super::{CubeElem, SharedMemoryExpand};
#[derive(Default, Clone)]
pub struct VariablePool {
@ -117,10 +117,8 @@ impl CubeContext {
}
}
pub fn create_local_array<T: CubeElem>(&mut self, item: Item, size: u32) -> ArrayExpand<T> {
ArrayExpand {
val: ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size)),
}
pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size))
}
/// Obtain the index-th input
@ -128,12 +126,6 @@ impl CubeContext {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
}
pub fn input_array<T: CubeElem>(&mut self, index: u16, item: Item) -> ArrayExpand<T> {
ArrayExpand {
val: ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item)),
}
}
/// Obtain the index-th output
pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray(index, item);
@ -141,14 +133,6 @@ impl CubeContext {
ExpandElement::Plain(var)
}
pub fn output_array<T: CubeElem>(&mut self, index: u16, item: Item) -> ArrayExpand<T> {
let var = crate::ir::Variable::GlobalOutputArray(index, item);
self.scope.borrow_mut().write_global_custom(var);
ArrayExpand {
val: ExpandElement::Plain(var),
}
}
/// Obtain the index-th scalar
pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))

View File

@ -8,44 +8,18 @@ use crate::{
};
use crate::{
frontend::{indexation::Index, CubeContext, CubeElem},
ir::Variable,
prelude::{assign, index, index_assign, Comptime},
};
use super::{ArgSettings, LaunchArg, TensorHandle, UInt};
use super::Init;
#[derive(Clone, Copy)]
pub struct Array<T: CubeType> {
_val: PhantomData<T>,
}
#[derive(Clone)]
pub struct ArrayExpand<T: CubeElem> {
pub val: <T as CubeType>::ExpandType,
}
impl<T: CubeElem> From<ArrayExpand<T>> for ExpandElement {
fn from(array_expand: ArrayExpand<T>) -> Self {
array_expand.val
}
}
impl<T: CubeElem> From<ArrayExpand<T>> for Variable {
fn from(array_expand: ArrayExpand<T>) -> Self {
*array_expand.val
}
}
impl<T: CubeElem> Init for ArrayExpand<T> {
fn init(self, _context: &mut CubeContext) -> Self {
self
}
}
impl<T: CubeElem> CubeType for Array<T> {
type ExpandType = ArrayExpand<T>;
type ExpandType = ExpandElement;
}
impl<T: CubeElem + Clone> Array<T> {
@ -90,23 +64,24 @@ impl<T: CubeElem + Clone> Array<T> {
}
}
impl<T: CubeElem> ArrayExpand<T> {
impl ExpandElement {
pub fn to_vectorized_expand(
self,
context: &mut CubeContext,
vectorization_factor: UInt,
) -> <T as CubeType>::ExpandType {
) -> ExpandElement {
let factor = vectorization_factor.val;
let var = *self;
let mut new_var = context.create_local(Item::vectorized(
T::as_elem(),
vectorization_factor.val as u8,
var.item().elem(),
factor as u8,
));
if vectorization_factor.val == 1 {
let element = index::expand(context, self.val.clone(), 0);
let element = index::expand(context, self.clone(), 0u32);
assign::expand(context, element, new_var.clone());
} else {
for i in 0..factor {
let element = index::expand(context, self.val.clone(), i);
let element = index::expand(context, self.clone(), i);
new_var = index_assign::expand(context, new_var, i, element);
}
}
@ -124,11 +99,11 @@ impl<E: CubeType> Array<E> {
impl<C: CubeElem> LaunchArg for Array<C> {
type RuntimeArg<'a, R: Runtime> = ArrayHandle<'a, R>;
fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ArrayExpand<C> {
fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
builder.input_array(Item::vectorized(C::as_elem(), vectorization))
}
fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ArrayExpand<C> {
fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
builder.output_array(Item::vectorized(C::as_elem(), vectorization))
}
}

View File

@ -96,7 +96,10 @@ impl From<ExpandElement> for Variable {
impl Init for ExpandElement {
fn init(self, context: &mut CubeContext) -> Self {
init_expand(context, self, Operator::Assign)
match *self {
Variable::LocalArray(_, _, _, _) => self,
_ => init_expand(context, self, Operator::Assign),
}
}
}

View File

@ -26,6 +26,7 @@ pub trait Float:
+ Erf
+ Recip
+ core::ops::Index<UInt, Output = Self>
+ core::ops::IndexMut<UInt, Output = Self>
{
fn new(val: f32) -> Self;
fn new_expand(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
@ -109,6 +110,12 @@ macro_rules! impl_float {
}
}
impl core::ops::IndexMut<UInt> for $type {
fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
unexpanded!()
}
}
impl LaunchArg for $type {
type RuntimeArg<'a, R: Runtime> = $primitive;

View File

@ -1,3 +1,4 @@
use super::{CubeType, ExpandElement, Tensor, UInt};
pub trait Vectorized {

View File

@ -113,6 +113,90 @@ pub mod index {
impl_index!(SharedMemory);
}
pub mod add_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;
use self::ir::Operator;
use super::*;
pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Add);
}
}
pub mod sub_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;
use self::ir::Operator;
use super::*;
pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
}
}
pub mod mul_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;
use self::ir::Operator;
use super::*;
pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
}
}
pub mod div_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;
use self::ir::Operator;
use super::*;
pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Div);
}
}
pub mod add_assign_op {
use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};

View File

@ -203,3 +203,43 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization
output
}
pub fn array_assign_binary_op_expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
F: Fn(BinaryOperator) -> Operator,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
func: F,
) {
let array: ExpandElement = array.into();
let index: ExpandElement = index.into();
let value: ExpandElement = value.into();
let tmp = context.create_local(array.item());
let read = Operator::Index(BinaryOperator {
lhs: *array,
rhs: *index,
out: *tmp,
});
let calculate = func(BinaryOperator {
lhs: *tmp,
rhs: *value,
out: *tmp,
});
let write = Operator::IndexAssign(BinaryOperator {
lhs: *index,
rhs: *tmp,
out: *array,
});
context.register(read);
context.register(calculate);
context.register(write);
}

View File

@ -22,11 +22,21 @@ fn array_of_one_to_vectorized_variable<T: Numeric>() -> T {
array.to_vectorized(Comptime::new(UInt::new(1)))
}
#[cube]
fn array_add_assign_simple(mut array: Array<UInt>) {
array[UInt::new(1)] += UInt::new(1);
}
#[cube]
fn array_add_assign_expr(mut array: Array<UInt>) {
array[UInt::new(1) + UInt::new(5)] += UInt::new(1);
}
mod tests {
use super::*;
use burn_cube::{
cpa,
ir::{Item, Variable},
ir::{Elem, Item, Variable},
};
type ElemType = F32;
@ -39,6 +49,20 @@ mod tests {
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref_read_write()
)
}
#[test]
fn array_add_assign() {
let mut context = CubeContext::root();
let array = context.input(0, Item::new(Elem::UInt));
array_add_assign_simple_expand(&mut context, array);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_array_add_assign_simple()
);
}
@ -84,6 +108,37 @@ mod tests {
format!("{:?}", scope.operations)
}
#[test]
fn array_add_assign_expr() {
let mut context = CubeContext::root();
let array = context.input(0, Item::new(Elem::UInt));
array_add_assign_expr_expand(&mut context, array);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_array_add_assign_expr()
);
}
fn inline_macro_array_add_assign_simple() -> String {
let context = CubeContext::root();
let mut scope = context.into_scope();
let local = scope.create_local(Item::new(Elem::UInt));
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
let index = Variable::ConstantScalar(1., Elem::UInt);
let value = Variable::ConstantScalar(1., Elem::UInt);
cpa!(scope, local = array[index]);
cpa!(scope, local += value);
cpa!(scope, array[index] = local);
format!("{:?}", scope.operations)
}
fn inline_macro_ref_to_vectorized() -> String {
let context = CubeContext::root();
let scalar_item = Item::new(ElemType::as_elem());
@ -123,4 +178,24 @@ mod tests {
format!("{:?}", scope.operations)
}
fn inline_macro_array_add_assign_expr() -> String {
let context = CubeContext::root();
let mut scope = context.into_scope();
let index = scope.create_local(Item::new(Elem::UInt));
let local = scope.create_local(Item::new(Elem::UInt));
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
let const1 = Variable::ConstantScalar(1., Elem::UInt);
let const2 = Variable::ConstantScalar(5., Elem::UInt);
let value = Variable::ConstantScalar(1., Elem::UInt);
cpa!(scope, index = const1 + const2);
cpa!(scope, local = array[index]);
cpa!(scope, local += value);
cpa!(scope, array[index] = local);
format!("{:?}", scope.operations)
}
}

View File

@ -13,7 +13,7 @@ pub(crate) mod record;
pub(crate) mod shared;
/// Derive macro for the module.
#[proc_macro_derive(Module)]
#[proc_macro_derive(Module, attributes(module))]
pub fn module_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
module::derive_impl(&input)

View File

@ -2,7 +2,7 @@ use super::{display, record::ModuleRecordCodegen};
use crate::shared::generics::GenericsHelper;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use syn::{parse_quote, Attribute, Generics};
/// Basic trait to be implemented for Module generation.
pub(crate) trait ModuleCodegen {
@ -30,8 +30,8 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
let generics = GenericsParser::from_ast(&ast.generics);
let display_fn = display::display_fn(name);
let display_fn = display::display_fn(ast);
let attributes_fn = display::attributes_fn(ast);
let num_params_fn = codegen.gen_num_params();
let visit = codegen.gen_visit();
let map_mut = codegen.gen_map();
@ -54,7 +54,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
let generics_ty_inner_module = generics.inner_module_ty;
let gen = quote! {
let mut gen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
type Record = #record_name #generics_ty_module;
@ -69,6 +69,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#collect_devices
#to_device
#fork
}
impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff
@ -82,6 +83,15 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#display_fn
}
impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {
#attributes_fn
fn num_params(&self) -> usize {
burn::module::Module::num_params(self)
}
}
impl #generics_module Clone for #name #generics_ty_module #generics_where_module {
#clone_fn
}
@ -89,13 +99,21 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#record_type
};
if !has_custom_display(&ast.attrs) {
gen.extend(quote! {
impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {
}
});
}
gen
}
// When there is no backend in the generic parameter, the type is considered as a constant.
pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let (_generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};
@ -112,7 +130,10 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
let (generics_module, _, _) = generics_module.split_for_impl();
let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();
let gen = quote! {
let display_fn = display::display_fn(ast);
let attributes_fn = display::attributes_fn(ast);
let mut gen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
burn::constant!(module);
}
@ -121,8 +142,26 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
for #name #generics_ty #generics_where {
burn::constant!(ad_module, #name #generics_ty);
}
impl #generics core::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {
#attributes_fn
}
};
if !has_custom_display(&ast.attrs) {
gen.extend(quote! {
impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {
}
});
}
gen
}
@ -159,22 +198,64 @@ impl GenericsParser {
#ident: burn::module::Module<B>
}
);
module.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplayDefault
}
);
module.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::AutodiffModule<B>
}
);
module_autodiff.add_predicate(
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>
}
);
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
}
);
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
}
);
generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::Module<B::InnerBackend>
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplayDefault
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
});
module.consts().into_iter().for_each(|ident| {
@ -188,3 +269,18 @@ impl GenericsParser {
}
}
}
fn has_custom_display(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| {
attr.path().is_ident("module")
&& attr
.parse_nested_meta(|meta| {
if meta.path.is_ident("custom_display") {
Ok(())
} else {
Err(meta.error("unsupported attribute"))
}
})
.is_ok()
})
}

View File

@ -1,11 +1,96 @@
use proc_macro2::Ident;
use quote::quote;
pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream {
pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
match &ast.data {
syn::Data::Struct(ref data_struct) => {
let fields = match &data_struct.fields {
syn::Fields::Named(ref named_fields) => {
named_fields.named.iter().collect::<Vec<_>>()
}
syn::Fields::Unit => Vec::new(),
_ => panic!("attributes_fn only supports structs with named or unit fields"),
};
let field_prints = fields.iter().map(|field| {
let field_name = &field.ident;
quote! { .add(stringify!(#field_name), &self.#field_name) }
});
let struct_name = &ast.ident;
quote! {
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
content
.set_top_level_type(&stringify!(#struct_name))
#(#field_prints)*
.optional()
}
}
}
syn::Data::Enum(ref data_enum) => {
let variant_prints = data_enum.variants.iter().map(|variant| {
let variant_name = &variant.ident;
match &variant.fields {
syn::Fields::Unit => {
quote! {
Self::#variant_name => {
content.add_formatted(&stringify!(#variant_name).to_string())
.optional()
}
}
}
syn::Fields::Named(ref named_fields) => {
let field_prints = named_fields.named.iter().map(|field| {
let field_name = &field.ident;
quote! { .add(stringify!(#field_name), &self.#field_name) }
});
let field_names = named_fields.named.iter().map(|field| {
let field_name = &field.ident;
quote! { #field_name }
});
quote! {
Self::#variant_name { #(#field_names),* } => {
content.set_top_level_type(&stringify!(#variant_name))
#(#field_prints)*
.optional()
}
}
}
syn::Fields::Unnamed(ref unnamed_fields) => {
let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {
syn::Ident::new(&format!("_{}", i), proc_macro2::Span::call_site())
});
let field_prints = field_names.clone().map(|field_name| {
quote! { .add(stringify!(#field_name), #field_name) }
});
quote! {
Self::#variant_name(#(#field_names),*) => {
content.set_top_level_type(&stringify!(#variant_name))
#(#field_prints)*
.optional()
}
}
}
}
});
quote! {
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
match self {
#(#variant_prints)*
}
}
}
}
_ => panic!("attributes_fn only supports structs and enums"),
}
}
pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
quote! {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}[num_params={}]", stringify!(#name), self.num_params())
let formatted = burn::module::ModuleDisplay::format(self, Default::default());
write!(f, "{}", formatted)
}
}
}

View File

@ -148,23 +148,25 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
let client = lhs.client.clone();
let lhs = match lhs.batch_swapped_with_row_col() {
true => into_contiguous(lhs),
false => lhs,
};
let rhs = match rhs.batch_swapped_with_row_col() {
true => into_contiguous(rhs),
false => rhs,
};
let lhs = into_contiguous(lhs);
let rhs = into_contiguous(rhs);
// let lhs = match lhs.batch_swapped_with_row_col() {
// true => into_contiguous(lhs),
// false => lhs,
// };
// let rhs = match rhs.batch_swapped_with_row_col() {
// true => into_contiguous(rhs),
// false => rhs,
// };
config.block_size_m = 16;
config.block_size_n = 16;
config.block_size_k = 16; // k must be <= both m and n
config.block_size_m = 64;
config.block_size_n = 64;
config.block_size_k = 32; // k must be <= both m and n
let cube_count = tiling2d_launch_options(&out.shape, config.clone());
let vectorization_factor = 4;
let x = (config.block_size_m / vectorization_factor) as u32;
let y = (config.block_size_n / vectorization_factor) as u32;
let x = (config.block_size_m / 4) as u32;
let y = (config.block_size_n / 4) as u32;
let settings = KernelSettings::default()
.vectorize_input(0, vectorization_factor as u8)
.vectorize_input(1, vectorization_factor as u8)

View File

@ -9,11 +9,12 @@ use super::{
};
#[cube]
#[allow(unused_mut)]
pub(crate) fn compute_loop<F: Float>(
coordinates: Coordinates,
shared_lhs: SharedMemory<F>,
shared_rhs: SharedMemory<F>,
results: Array<F>,
mut results: Array<F>,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
@ -25,14 +26,13 @@ pub(crate) fn compute_loop<F: Float>(
let unit_row = coordinates.unit_row;
let unit_col = coordinates.unit_col;
let lhs_stride = block_size_m / tile_size;
let rhs_stride = block_size_n / tile_size;
for dot_index in range(0u32, block_size_k, unroll) {
let register_m = shared_lhs[(unit_col + dot_index) * Comptime::runtime(lhs_stride)];
let register_n = shared_rhs[(unit_row + dot_index) * Comptime::runtime(rhs_stride)];
let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
/ Comptime::runtime(tile_size)];
let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
/ Comptime::runtime(tile_size)];
tile_outer_product(register_m, register_n, results, config);
tile_outer_product::<F>(register_m, register_n, results, config);
}
}
@ -51,16 +51,19 @@ pub mod tests {
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let block_size_k = Comptime::map(config, |c| c.block_size_m);
let block_size_n = Comptime::map(config, |c| c.block_size_m);
let sm_size_lhs = block_size_m * block_size_k / tile_size;
let sm_size_rhs = block_size_n * block_size_k / tile_size;
// Shared memories are not launchable, so we launch with tensor and convert to shared memory
let sm_size_lhs = Comptime::map(config, |c| c.sm_size_lhs);
let mut shared_lhs =
SharedMemory::<F>::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size));
for i in range(0u32, lhs.len(), Comptime::new(false)) {
shared_lhs[i] = lhs[i];
}
let sm_size_rhs = Comptime::map(config, |c| c.sm_size_rhs);
let mut shared_rhs =
SharedMemory::<F>::vectorized(Comptime::get(sm_size_rhs), Comptime::get(tile_size));
for i in range(0u32, rhs.len(), Comptime::new(false)) {
@ -135,4 +138,62 @@ pub mod tests {
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn compute_loop_unit_offset_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let lhs = burn_tensor::Tensor::<B<R>, 2>::from_data(
burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..32, device)
.reshape([8, 4])
.float()
.transpose()
.into_data(),
device,
)
.into_primitive();
let rhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..32, device)
.reshape([4, 8])
.float()
.into_primitive();
let client = R::client(device);
let unit_row = 4;
let unit_col = 4;
let results = client.empty(tile_size * tile_size * core::mem::size_of::<f32>());
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(1, 1, 1))
.vectorize_input(0, tile_size as u8)
.vectorize_input(1, tile_size as u8);
let mut tiling2d_config = Tiling2dConfig::default();
tiling2d_config.block_size_m = 8;
tiling2d_config.block_size_k = 8;
tiling2d_config.block_size_n = 8;
let config = CubeTiling2dConfig::new(tiling2d_config, 4, 8, 4, tile_size);
compute_loop_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
unit_row,
unit_col,
ArrayHandle::new(&results, 1),
config,
);
let actual = client.read(results.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0,
1978.0, 1928.0, 2046.0, 2164.0, 2282.0,
];
assert_eq!(actual, expected);
}
}

View File

@ -25,19 +25,12 @@ pub struct CubeTiling2dConfig {
pub check_k_bounds: bool,
/// Bounds must be checked on rhs dimension
pub check_n_bounds: bool,
/// Shared memory size lhs: technically derivable from others, but needs comptime arithmetic
pub sm_size_lhs: UInt,
/// Shared memory size rhs: technically derivable from others, but needs comptime arithmetic
pub sm_size_rhs: UInt,
/// Tile size. Should correspond to vectorization of inputs/outputs/shared memory
pub tile_size: UInt,
}
impl CubeTiling2dConfig {
pub fn new(config: Tiling2dConfig, m: usize, k: usize, n: usize, tile_size: usize) -> Self {
let sm_size_lhs = (config.block_size_m / tile_size) * config.block_size_k;
let sm_size_rhs = (config.block_size_n / tile_size) * config.block_size_k;
CubeTiling2dConfig {
block_size_m: UInt::new(config.block_size_m as u32),
block_size_k: UInt::new(config.block_size_k as u32),
@ -46,8 +39,6 @@ impl CubeTiling2dConfig {
check_m_bounds: m % config.block_size_m != 0,
check_k_bounds: k % config.block_size_k != 0,
check_n_bounds: n % config.block_size_n != 0,
sm_size_lhs: UInt::new(sm_size_lhs as u32),
sm_size_rhs: UInt::new(sm_size_rhs as u32),
tile_size: UInt::new(tile_size as u32),
}
}

View File

@ -33,12 +33,12 @@ pub(crate) fn load_lhs_transposed<F: Float>(
let tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
load_tile(
load_tile::<F>(
lhs,
tile,
offset,
unit_col,
unit_row,
unit_col,
skip_row,
skip_col,
Comptime::map(config, |c| c.check_m_bounds),
@ -46,7 +46,7 @@ pub(crate) fn load_lhs_transposed<F: Float>(
config,
);
write_tile_transposed(
write_tile_transposed::<F>(
tile,
shared_lhs,
sm_position_base,
@ -83,7 +83,7 @@ pub(crate) fn load_rhs_plain<F: Float>(
let tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
load_tile(
load_tile::<F>(
rhs,
tile,
offset,
@ -96,7 +96,7 @@ pub(crate) fn load_rhs_plain<F: Float>(
config,
);
write_tile_plain(
write_tile_plain::<F>(
tile,
shared_rhs,
sm_position_base,
@ -128,16 +128,16 @@ fn load_tile<F: Float>(
if Comptime::get(check_vertical_bounds) {
let row = skip_row + load_row;
let dim_vertical = tensor.shape(tensor.rank() - UInt::new(2));
let dim_vertical = tensor.shape(rank - UInt::new(2));
if Comptime::get(check_horizontal_bounds) {
let col = skip_col + load_col;
let dim_horizontal = tensor.shape(tensor.rank() - UInt::new(1));
let dim_horizontal = tensor.shape(rank - UInt::new(1));
if col >= dim_horizontal {
read_zeros(tile, tile_size, unroll);
read_zeros::<F>(tile, tile_size, unroll);
} else {
read_partial(
read_partial::<F>(
tensor,
dim_vertical,
row,
@ -148,7 +148,7 @@ fn load_tile<F: Float>(
);
}
} else {
read_partial(
read_partial::<F>(
tensor,
dim_vertical,
row,
@ -161,11 +161,11 @@ fn load_tile<F: Float>(
} else {
if Comptime::get(check_horizontal_bounds) {
let col = skip_col + load_col;
let dim_horizontal = tensor.shape(tensor.rank() - UInt::new(1));
let dim_horizontal = tensor.shape(rank - UInt::new(1));
if col >= dim_horizontal {
read_zeros(tile, tile_size, unroll);
read_zeros::<F>(tile, tile_size, unroll);
} else {
read_whole(
read_whole::<F>(
tensor,
tensor_position_base,
tensor_stride,
@ -175,7 +175,7 @@ fn load_tile<F: Float>(
);
}
} else {
read_whole(
read_whole::<F>(
tensor,
tensor_position_base,
tensor_stride,
@ -221,12 +221,16 @@ fn write_tile_transposed<F: Float>(
shared_memory[sm_position_base] = tile[0];
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
let mut transposed = Array::<F>::new(Comptime::get(tile_size));
let mut transposed = F::vectorized(0., Comptime::get(tile_size));
for j in range(0u32, Comptime::get(tile_size), unroll) {
transposed[j] = tile[j][i];
}
let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization;
shared_memory[sm_position] = transposed.to_vectorized(tile_size);
shared_memory[sm_position] = transposed;
// let mut x = F::vectorized(0., Comptime::get(tile_size));
// x[UInt::new(0)] = F::cast_from(UNIT_POS);
// shared_memory[UNIT_POS] = x;
}
}
}
@ -286,7 +290,7 @@ pub mod tests {
#[cube(launch)]
#[allow(unused_mut)]
fn read_whole_test<F: Float>(tensor: Tensor<F>, mut tile: Array<F>, tile_size: Comptime<UInt>) {
read_whole(
read_whole::<F>(
tensor,
UInt::new(0),
tensor.stride(0),
@ -303,7 +307,7 @@ pub mod tests {
mut tile: Array<F>,
tile_size: Comptime<UInt>,
) {
read_partial(
read_partial::<F>(
tensor,
Comptime::runtime(tile_size),
UInt::new(2),
@ -317,7 +321,7 @@ pub mod tests {
#[cube(launch)]
#[allow(unused_mut)]
fn read_zeros_test<F: Float>(mut tile: Array<F>, tile_size: Comptime<UInt>) {
read_zeros(tile, tile_size, Comptime::new(true))
read_zeros::<F>(tile, tile_size, Comptime::new(true))
}
#[cube(launch)]
@ -361,7 +365,7 @@ pub mod tests {
let sm_stride = block_size_m;
let sm_size = Comptime::runtime(block_size_k * block_size_m);
let shared_memory = SharedMemory::vectorized(sm_size, Comptime::get(tile_size));
let shared_memory = SharedMemory::<F>::vectorized(sm_size, Comptime::get(tile_size));
if Comptime::get(transposed) {
write_tile_transposed(
@ -424,6 +428,42 @@ pub mod tests {
}
}
#[cube(launch)]
fn load_tensor_multiple_tiles_test<F: Float>(
tensor: Tensor<F>,
mut sm_out: Array<F>,
k: UInt,
config: Comptime<CubeTiling2dConfig>,
is_lhs: Comptime<bool>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let sm_size = block_size_k * block_size_m / tile_size;
let shared_memory =
SharedMemory::<F>::vectorized(Comptime::get(sm_size), Comptime::get(tile_size));
let unit_row = UInt::new(4) * UNIT_POS_X;
let unit_col = UInt::new(4) * UNIT_POS_Y;
let offset = UInt::new(0);
let coordinates = Coordinates {
unit_row,
unit_col,
skip_row: UInt::new(0),
skip_col: UInt::new(0),
};
if Comptime::get(is_lhs) {
load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config);
} else {
load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config);
}
for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) {
sm_out[i] = shared_memory[i];
}
}
/// Exported test
pub fn read_whole_unit_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
@ -817,6 +857,112 @@ pub mod tests {
assert_eq!(actual, expected);
}
/// Exported test
pub fn load_lhs_transposed_cube_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let lhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..64, device)
.reshape([8, 8])
.float()
.into_primitive();
let client = R::client(device);
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(2, 2, 1))
.vectorize_input(0, tile_size as u8)
.vectorize_output(0, tile_size as u8);
let mut tiling2d_config = Tiling2dConfig::default();
tiling2d_config.block_size_m = 8;
tiling2d_config.block_size_k = 8;
tiling2d_config.block_size_n = 8;
let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 8, 8, 8, tile_size);
let sm_out = client.empty(
tiling2d_config.block_size_k
* tiling2d_config.block_size_m
* core::mem::size_of::<f32>(),
);
load_tensor_multiple_tiles_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
ArrayHandle::new(&sm_out, 64),
0,
config,
true,
);
let actual = client.read(sm_out.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0,
57.0, 2.0, 10.0, 18.0, 26.0, 34.0, 42.0, 50.0, 58.0, 3.0, 11.0, 19.0, 27.0, 35.0, 43.0,
51.0, 59.0, 4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 5.0, 13.0, 21.0, 29.0, 37.0,
45.0, 53.0, 61.0, 6.0, 14.0, 22.0, 30.0, 38.0, 46.0, 54.0, 62.0, 7.0, 15.0, 23.0, 31.0,
39.0, 47.0, 55.0, 63.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn load_lhs_transposed_offset_cube_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let lhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..128, device)
.reshape([8, 16])
.float()
.into_primitive();
let client = R::client(device);
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(2, 2, 1))
.vectorize_input(0, tile_size as u8)
.vectorize_output(0, tile_size as u8);
let mut tiling2d_config = Tiling2dConfig::default();
tiling2d_config.block_size_m = 8;
tiling2d_config.block_size_k = 8;
tiling2d_config.block_size_n = 8;
let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 8, 8, 16, tile_size);
let sm_out = client.empty(
tiling2d_config.block_size_k
* tiling2d_config.block_size_m
* core::mem::size_of::<f32>(),
);
load_tensor_multiple_tiles_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
ArrayHandle::new(&sm_out, 64),
8,
config,
true,
);
let actual = client.read(sm_out.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
8.0, 24.0, 40.0, 56.0, 72.0, 88.0, 104.0, 120.0, 9.0, 25.0, 41.0, 57.0, 73.0, 89.0,
105.0, 121.0, 10.0, 26.0, 42.0, 58.0, 74.0, 90.0, 106.0, 122.0, 11.0, 27.0, 43.0, 59.0,
75.0, 91.0, 107.0, 123.0, 12.0, 28.0, 44.0, 60.0, 76.0, 92.0, 108.0, 124.0, 13.0, 29.0,
45.0, 61.0, 77.0, 93.0, 109.0, 125.0, 14.0, 30.0, 46.0, 62.0, 78.0, 94.0, 110.0, 126.0,
15.0, 31.0, 47.0, 63.0, 79.0, 95.0, 111.0, 127.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn load_rhs_plain_unit_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
@ -868,4 +1014,110 @@ pub mod tests {
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn load_rhs_plain_cube_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let rhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..64, device)
.reshape([8, 8])
.float()
.into_primitive();
let client = R::client(device);
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(2, 2, 1))
.vectorize_input(0, tile_size as u8)
.vectorize_output(0, tile_size as u8);
let mut tiling2d_config = Tiling2dConfig::default();
tiling2d_config.block_size_m = 8;
tiling2d_config.block_size_k = 8;
tiling2d_config.block_size_n = 8;
let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 8, 8, 8, tile_size);
let sm_out = client.empty(
tiling2d_config.block_size_k
* tiling2d_config.block_size_m
* core::mem::size_of::<f32>(),
);
load_tensor_multiple_tiles_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
ArrayHandle::new(&sm_out, 64),
0,
config,
false,
);
let actual = client.read(sm_out.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0,
44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0,
58.0, 59.0, 60.0, 61.0, 62.0, 63.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn load_rhs_plain_cube_offset_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let rhs = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..128, device)
.reshape([16, 8])
.float()
.into_primitive();
let client = R::client(device);
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(2, 2, 1))
.vectorize_input(0, tile_size as u8)
.vectorize_output(0, tile_size as u8);
let mut tiling2d_config = Tiling2dConfig::default();
tiling2d_config.block_size_m = 8;
tiling2d_config.block_size_k = 8;
tiling2d_config.block_size_n = 8;
let config = CubeTiling2dConfig::new(tiling2d_config.clone(), 16, 16, 8, tile_size);
let sm_out = client.empty(
tiling2d_config.block_size_k
* tiling2d_config.block_size_m
* core::mem::size_of::<f32>(),
);
load_tensor_multiple_tiles_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
ArrayHandle::new(&sm_out, 64),
8,
config,
false,
);
let actual = client.read(sm_out.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0,
78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0,
92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0,
105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0,
117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0,
];
assert_eq!(actual, expected);
}
}

View File

@ -14,18 +14,15 @@ pub(crate) fn tile_outer_product<F: Float>(
let tile_size = Comptime::map(config, |c| c.tile_size);
let unroll = Comptime::map(config, |c| c.unroll);
let is_scalar = Comptime::map(tile_size, |c| c.val == 1);
if Comptime::get(is_scalar) {
// works
results[0] = results[0] + register_m * register_n;
// doesnt work
results[0] += register_m * register_n;
} else {
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
let res_pos_base = res_idx_m * Comptime::runtime(tile_size);
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
let mul = register_m[res_idx_m] * register_n[res_idx_n];
// results[res_pos_base + res_idx_n] += mul;
results[res_pos_base + res_idx_n] = results[res_pos_base + res_idx_n] + mul;
results[res_pos_base + res_idx_n] += mul;
}
}
}
@ -57,7 +54,7 @@ pub mod tests {
) {
results[i] = F::new(0.);
}
tile_outer_product(register_m, register_n, results, config)
tile_outer_product::<F>(register_m, register_n, results, config)
}
fn test_case_config(tile_size: usize) -> CubeTiling2dConfig {
@ -97,6 +94,38 @@ pub mod tests {
assert_eq!(actual, expected);
}
/// Exported test
pub fn tile_outer_product_vectorized_unit_test_2<R: Runtime>(device: &R::Device) {
let client = R::client(device);
let register_m = client.create(f32::as_bytes(&[16., 20., 24., 28.]));
let register_n = client.create(f32::as_bytes(&[4., 5., 6., 7.]));
let results = client.empty(16 * core::mem::size_of::<f32>());
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default().cube_dim(CubeDim::new(1, 1, 1));
let config = test_case_config(4);
tile_outer_product_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
ArrayHandle::new(&register_m, 4),
ArrayHandle::new(&register_n, 4),
ArrayHandle::new(&results, 16),
config,
);
let actual = client.read(results.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0,
140.0, 168.0, 196.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn tile_outer_product_scalar_unit_test<R: Runtime>(device: &R::Device) {
let client = R::client(device);

View File

@ -22,24 +22,24 @@ pub(crate) fn tiling2d_core<F: Float>(
config: Comptime<CubeTiling2dConfig>,
) {
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let results = init_results(config);
let results = init_results::<F>(config);
let n_loops = calculate_n_loops::<F>(lhs.shape(lhs.rank() - UInt::new(1)), config);
for k in range(0u32, n_loops, Comptime::new(false)) {
let k = k * Comptime::runtime(block_size_k);
load_lhs_transposed(lhs, coordinates, k, offsets.lhs, shared.lhs, config);
load_rhs_plain(rhs, coordinates, k, offsets.rhs, shared.rhs, config);
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config);
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config);
sync_units();
compute_loop(coordinates, shared.lhs, shared.rhs, results, config);
compute_loop::<F>(coordinates, shared.lhs, shared.rhs, results, config);
sync_units();
}
write_to_output(out, results, coordinates, offsets.out, config);
write_to_output::<F>(out, results, coordinates, offsets.out, config);
}
#[cube]

View File

@ -20,7 +20,7 @@ pub(crate) fn write_to_output<F: Float>(
let col = coordinates.skip_col + coordinates.unit_col;
let rank = out.rank();
let out_stride_row = out.stride(rank - UInt::new(2)) / Comptime::runtime(tile_size);
let out_stride_row = out.stride(rank - UInt::new(2));
if Comptime::get(check_m_bounds) {
let dim_m = out.shape(rank - UInt::new(2));
@ -28,7 +28,7 @@ pub(crate) fn write_to_output<F: Float>(
let dim_n = out.shape(rank - UInt::new(1));
if row < dim_m && col < dim_n {
let num_writes = UInt::min(dim_m - row, Comptime::runtime(tile_size));
write_results_to_output_partial(
write_results_to_output_partial::<F>(
out,
results,
row,
@ -42,7 +42,7 @@ pub(crate) fn write_to_output<F: Float>(
} else {
if row < dim_m {
let num_writes = UInt::min(dim_m - row, Comptime::runtime(tile_size));
write_results_to_output_partial(
write_results_to_output_partial::<F>(
out,
results,
row,
@ -58,7 +58,7 @@ pub(crate) fn write_to_output<F: Float>(
if Comptime::get(check_n_bounds) {
let dim_n = out.shape(rank - UInt::new(1));
if col < dim_n {
write_results_to_output(
write_results_to_output::<F>(
out,
results,
row,
@ -69,7 +69,7 @@ pub(crate) fn write_to_output<F: Float>(
);
}
} else {
write_results_to_output(
write_results_to_output::<F>(
out,
results,
row,
@ -93,14 +93,15 @@ fn write_results_to_output<F: Float>(
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let is_scalar = Comptime::map(tile_size, |t| t.val == 1);
let sm_is_scalar = Comptime::map(tile_size, |t| t.val == 1);
let unroll = Comptime::map(config, |c| c.unroll);
if Comptime::get(is_scalar) {
out[row * out_stride_row + col + offset_output] = results[0];
let vectorization_factor = Comptime::runtime(Comptime::vectorization(out));
if Comptime::get(sm_is_scalar) {
out[(row * out_stride_row + col + offset_output) / vectorization_factor] = results[0];
} else {
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
write_results_inner_loop(
write_results_inner_loop::<F>(
out,
results,
res_idx_m,
@ -126,13 +127,14 @@ fn write_results_to_output_partial<F: Float>(
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let is_scalar = Comptime::map(tile_size, |t| t.val == 1);
let sm_is_scalar = Comptime::map(tile_size, |t| t.val == 1);
if Comptime::get(is_scalar) {
out[row * out_stride_row + col + batch_offset / Comptime::runtime(tile_size)] = results[0];
let vectorization_factor = Comptime::runtime(Comptime::vectorization(out));
if Comptime::get(sm_is_scalar) {
out[(row * out_stride_row + col + batch_offset) / vectorization_factor] = results[0];
} else {
for res_idx_m in range(0u32, num_writes, Comptime::new(false)) {
write_results_inner_loop(
write_results_inner_loop::<F>(
out,
results,
res_idx_m,
@ -161,16 +163,16 @@ fn write_results_inner_loop<F: Float>(
let unroll = Comptime::map(config, |c| c.unroll);
let results_pos_m = res_idx_m * Comptime::runtime(tile_size);
let out_position =
(row + res_idx_m) * out_stride_row + col / Comptime::runtime(tile_size) + batch_offset;
let out_position = (row + res_idx_m) * out_stride_row + col + batch_offset;
// Reinterpreting results as vectorized array
let mut array = Array::<F>::new(Comptime::get(tile_size));
let mut output = F::vectorized(0., Comptime::get(tile_size));
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
array[res_idx_n] = results[results_pos_m + res_idx_n];
output[res_idx_n] = results[results_pos_m + res_idx_n];
}
out[out_position] = array.to_vectorized(tile_size);
let vectorization_factor = Comptime::runtime(Comptime::vectorization(out));
out[out_position / vectorization_factor] = output;
}
#[cfg(feature = "export_tests")]
@ -184,9 +186,8 @@ pub mod tests {
results: Array<F>,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let out_stride_row = out.stride(out.rank() - UInt::new(2)) / Comptime::runtime(tile_size);
write_results_inner_loop(
let out_stride_row = out.stride(out.rank() - UInt::new(2));
write_results_inner_loop::<F>(
out,
results,
UInt::new(2),
@ -204,9 +205,8 @@ pub mod tests {
results: Array<F>,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let out_stride_row = out.stride(out.rank() - UInt::new(2)) / Comptime::runtime(tile_size);
write_results_to_output(
let out_stride_row = out.stride(out.rank() - UInt::new(2));
write_results_to_output::<F>(
out,
results,
UInt::new(4),
@ -223,9 +223,8 @@ pub mod tests {
results: Array<F>,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let out_stride_row = out.stride(out.rank() - UInt::new(2)) / Comptime::runtime(tile_size);
write_results_to_output_partial(
let out_stride_row = out.stride(out.rank() - UInt::new(2));
write_results_to_output_partial::<F>(
out,
results,
UInt::new(4),
@ -249,7 +248,7 @@ pub mod tests {
skip_row: UInt::new(0),
skip_col: UInt::new(0),
};
write_to_output(out, results, coordinates, UInt::new(0), config);
write_to_output::<F>(out, results, coordinates, UInt::new(0), config);
}
#[cube(launch)]
@ -264,7 +263,7 @@ pub mod tests {
skip_row: UInt::new(0),
skip_col: UInt::new(0),
};
write_to_output(out, results, coordinates, UInt::new(0), config);
write_to_output::<F>(out, results, coordinates, UInt::new(0), config);
}
/// Exported test

View File

@ -478,7 +478,7 @@ mod tests {
#[test]
pub fn medium() {
test_with_params(16, 16, 16, 1, 1);
test_with_params(17, 16, 16, 1, 1);
}
#[test]

View File

@ -15,6 +15,13 @@ mod tests {
)
}
#[test]
pub fn tiling2d_matmul_outer_product_vectorized_test_2() {
outer_product_tests::tile_outer_product_vectorized_unit_test_2::<TestRuntime>(
&Default::default(),
)
}
#[test]
pub fn tiling2d_matmul_outer_product_scalar_test() {
outer_product_tests::tile_outer_product_scalar_unit_test::<TestRuntime>(&Default::default())
@ -25,6 +32,11 @@ mod tests {
compute_loop_tests::compute_loop_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn compute_loop_unit_offset_test() {
compute_loop_tests::compute_loop_unit_offset_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_read_whole_vectorized_test() {
load_shared_memory_tests::read_whole_unit_test::<TestRuntime>(&Default::default())
@ -74,11 +86,33 @@ mod tests {
load_shared_memory_tests::load_lhs_transposed_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn load_lhs_transposed_cube_test() {
load_shared_memory_tests::load_lhs_transposed_cube_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn load_lhs_transposed_offset_cube_test() {
load_shared_memory_tests::load_lhs_transposed_offset_cube_test::<TestRuntime>(
&Default::default(),
)
}
#[test]
pub fn load_rhs_plain_unit_test() {
load_shared_memory_tests::load_rhs_plain_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn load_rhs_plain_cube_test() {
load_shared_memory_tests::load_rhs_plain_cube_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn load_rhs_plain_cube_offset_test() {
load_shared_memory_tests::load_rhs_plain_cube_offset_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn write_results_inner_loop_unit_test() {
write_output_tests::write_results_inner_loop_unit_test::<TestRuntime>(&Default::default())

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(matmul)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
use burn_tensor::{Data, Int, Tensor};
#[test]
fn test_matmul_d2() {
@ -111,14 +111,102 @@ mod tests {
#[test]
fn test_matmul_trivial() {
let device = Default::default();
let tensor_1 = TestTensor::from_floats([[3., 4.]], &device);
let tensor_2 = TestTensor::from_floats([[4.], [3.]], &device);
// let tensor_3 = tensor_1.matmul(tensor_2);
let tensor_3 = tensor_2.matmul(tensor_1);
let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..16, &device)
.reshape([4, 4])
.float();
// assert_eq!(tensor_3.into_data(), Data::from([[24.]]));
assert_eq!(tensor_3.into_data(), Data::from([[12., 16.], [9., 12.]]));
let tensor_3 = tensor_1.clone().matmul(tensor_1);
assert_eq!(
tensor_3.into_data(),
Data::from([
[56., 62., 68., 74.],
[152., 174., 196., 218.],
[248., 286., 324., 362.],
[344., 398., 452., 506.]
])
);
}
#[test]
fn test_matmul_trivial_transposed() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..16, &device)
.reshape([4, 4])
.float();
let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());
assert_eq!(
tensor_3.into_data(),
Data::from([
[14., 38., 62., 86.],
[38., 126., 214., 302.],
[62., 214., 366., 518.],
[86., 302., 518., 734.]
])
);
}
#[test]
fn test_matmul_4_8() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..32, &device)
.reshape([4, 8])
.float();
let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());
assert_eq!(
tensor_3.into_data(),
Data::from([
[140., 364., 588., 812.],
[364., 1100., 1836., 2572.],
[588., 1836., 3084., 4332.],
[812., 2572., 4332., 6092.]
])
);
}
#[test]
fn test_matmul_8_4() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 1, Int>::arange(0..32, &device)
.reshape([8, 4])
.float();
let tensor_2 = Tensor::<TestBackend, 1, Int>::arange(0..32, &device)
.reshape([4, 8])
.float();
let tensor_3 = tensor_1.clone().matmul(tensor_2);
assert_eq!(
tensor_3.into_data(),
Data::from([
[112., 118., 124., 130., 136., 142., 148., 154.],
[304., 326., 348., 370., 392., 414., 436., 458.],
[496., 534., 572., 610., 648., 686., 724., 762.],
[688., 742., 796., 850., 904., 958., 1012., 1066.],
[880., 950., 1020., 1090., 1160., 1230., 1300., 1370.],
[1072., 1158., 1244., 1330., 1416., 1502., 1588., 1674.],
[1264., 1366., 1468., 1570., 1672., 1774., 1876., 1978.],
[1456., 1574., 1692., 1810., 1928., 2046., 2164., 2282.]
],)
);
// [
// [112.0, 118.0, 124.0, 130.0, 880.0, 950.0, 1020.0, 1090.0],
// [304.0, 326.0, 348.0, 370.0, 1072.0, 1158.0, 1244.0, 1330.0],
// [496.0, 534.0, 572.0, 610.0, 1264.0, 1366.0, 1468.0, 1570.0],
// [688.0, 742.0, 796.0, 850.0, 1456.0, 1574.0, 1692.0, 1810.0],
// [136.0, 142.0, 148.0, 154.0, 1160.0, 1230.0, 1300.0, 1370.0],
// [392.0, 414.0, 436.0, 458.0, 1416.0, 1502.0, 1588.0, 1674.0],
// [648.0, 686.0, 724.0, 762.0, 1672.0, 1774.0, 1876.0, 1978.0],
// [904.0, 958.0, 1012.0, 1066.0, 1928.0, 2046.0, 2164.0, 2282.0],
// ]
}
#[test]

View File

@ -151,7 +151,7 @@ impl Display for LearnerSummary {
)?;
if let Some(model) = &self.model {
writeln!(f, "Model: {model}")?;
writeln!(f, "Model:\n{model}")?;
}
writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;

View File

@ -88,7 +88,7 @@ where
}
let compile = kernel.compile();
println!("{}", compile.source);
// println!("{}", compile.source);
let pipeline = self.compile_source(&compile.source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());