From 4efc683df45811e6362baf0c156d3f7649f5a3b2 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 29 Feb 2024 12:29:11 -0500 Subject: [PATCH] Upgrade to candle 0.4.1 (#1382) * Fix python main entrypoint in book example * Remove candle windows safeguards (#1178) * Bump candle-core from 0.3.3 to 0.4.1 * Remove windows current known issue --- Cargo.lock | 13 +++--- Cargo.toml | 2 +- burn-book/src/import/pytorch-model.md | 5 +-- crates/burn-candle/src/ops/module.rs | 5 +-- crates/burn-import/pytorch-tests/Cargo.toml | 1 - crates/burn-import/pytorch-tests/tests/mod.rs | 40 ++++++++----------- examples/pytorch-import/build.rs | 10 ----- 7 files changed, 27 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1026aebc0..e68738c00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -691,9 +691,9 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.3.3" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db8659ea87ee8197d2fc627348916cce0561330ee7ae3874e771691d3cecb2f" +checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00" dependencies = [ "accelerate-src", "byteorder", @@ -718,18 +718,18 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.3.3" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d80cdd5f1cc60d30ba61353cdba5accd0fbc4d4ef2fe707fcb5179a9821adbea" +checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.3.3" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52567e7a314ae0c59db5fbd4874ce461d99fa22adb22ddf7cf296b4d97035b40" +checksum = "b20d6c0d49121e2709ed9faa958ba915ea59526036bcf27558817d1452a4ff09" dependencies = [ "metal", "once_cell", @@ -3521,7 +3521,6 @@ dependencies = [ "burn", "burn-import", "burn-ndarray", - "cfg-if", "float-cmp", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index 854813a97..b9eb345b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ license = "MIT OR Apache-2.0" [workspace.dependencies] async-trait = "0.1.74" bytemuck = "1.14" -candle-core = { version = "0.3.3" } +candle-core = { version = "0.4.1" } clap = { version = "4.5.1", features = ["derive"] } console_error_panic_hook = "0.1.7" csv = "1.3.0" diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 2fd6a1aac..a2356a16e 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -30,7 +30,7 @@ class Net(nn.Module): x = self.conv2(x) return x -def main(): +if __name__ == "__main__": torch.manual_seed(42) # To make it reproducible model = Net().to(torch.device("cpu")) model_weights = model.state_dict() @@ -254,5 +254,4 @@ defining the encoder in Burn, allowing the loading of its weights while excludin ## Current known issues -1. [Candle's pickle library does not currently function on Windows due to a Candle bug](https://github.com/tracel-ai/burn/issues/1178). -2. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179). +1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179). diff --git a/crates/burn-candle/src/ops/module.rs b/crates/burn-candle/src/ops/module.rs index b0b1c2264..7639ee0be 100644 --- a/crates/burn-candle/src/ops/module.rs +++ b/crates/burn-candle/src/ops/module.rs @@ -83,10 +83,6 @@ impl ModuleOps for Candle>, options: ConvTransposeOptions<1>, ) -> FloatTensor { - assert!( - options.groups == 1, - "Candle does not support groups in transposed convolutions" - ); let conv_transpose = x .tensor .conv_transpose1d( @@ -95,6 +91,7 @@ impl ModuleOps for Candle; fn main() { - if cfg!(target_os = "windows") { - println!( - "{}", - "cargo:warning=The crate is not supported on Windows because of ".to_owned() - + "Candle's pt bug on Windows " - + "(see https://github.com/huggingface/candle/issues/1454)." - ); - std::process::exit(0); - } - let device = Default::default(); // Load PyTorch weights into a model record.