Upgrade to candle 0.4.1 (#1382)

* Fix python main entrypoint in book example

* Remove candle windows safeguards (#1178)

* Bump candle-core from 0.3.3 to 0.4.1

* Remove windows current known issue
This commit is contained in:
Guillaume Lagrange 2024-02-29 12:29:11 -05:00 committed by GitHub
parent 40bf3927f0
commit 4efc683df4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 27 additions and 49 deletions

13
Cargo.lock generated
View File

@ -691,9 +691,9 @@ dependencies = [
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.3.3" version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db8659ea87ee8197d2fc627348916cce0561330ee7ae3874e771691d3cecb2f" checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
dependencies = [ dependencies = [
"accelerate-src", "accelerate-src",
"byteorder", "byteorder",
@ -718,18 +718,18 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.3.3" version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d80cdd5f1cc60d30ba61353cdba5accd0fbc4d4ef2fe707fcb5179a9821adbea" checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
dependencies = [ dependencies = [
"bindgen_cuda", "bindgen_cuda",
] ]
[[package]] [[package]]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.3.3" version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52567e7a314ae0c59db5fbd4874ce461d99fa22adb22ddf7cf296b4d97035b40" checksum = "b20d6c0d49121e2709ed9faa958ba915ea59526036bcf27558817d1452a4ff09"
dependencies = [ dependencies = [
"metal", "metal",
"once_cell", "once_cell",
@ -3521,7 +3521,6 @@ dependencies = [
"burn", "burn",
"burn-import", "burn-import",
"burn-ndarray", "burn-ndarray",
"cfg-if",
"float-cmp", "float-cmp",
"serde", "serde",
] ]

View File

@ -25,7 +25,7 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies] [workspace.dependencies]
async-trait = "0.1.74" async-trait = "0.1.74"
bytemuck = "1.14" bytemuck = "1.14"
candle-core = { version = "0.3.3" } candle-core = { version = "0.4.1" }
clap = { version = "4.5.1", features = ["derive"] } clap = { version = "4.5.1", features = ["derive"] }
console_error_panic_hook = "0.1.7" console_error_panic_hook = "0.1.7"
csv = "1.3.0" csv = "1.3.0"

View File

@ -30,7 +30,7 @@ class Net(nn.Module):
x = self.conv2(x) x = self.conv2(x)
return x return x
def main(): if __name__ == "__main__":
torch.manual_seed(42) # To make it reproducible torch.manual_seed(42) # To make it reproducible
model = Net().to(torch.device("cpu")) model = Net().to(torch.device("cpu"))
model_weights = model.state_dict() model_weights = model.state_dict()
@ -254,5 +254,4 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
## Current known issues ## Current known issues
1. [Candle's pickle library does not currently function on Windows due to a Candle bug](https://github.com/tracel-ai/burn/issues/1178). 1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
2. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).

View File

@ -83,10 +83,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
bias: Option<FloatTensor<Self, 1>>, bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>, options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> { ) -> FloatTensor<Self, 3> {
assert!(
options.groups == 1,
"Candle does not support groups in transposed convolutions"
);
let conv_transpose = x let conv_transpose = x
.tensor .tensor
.conv_transpose1d( .conv_transpose1d(
@ -95,6 +91,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
options.padding_out[0], options.padding_out[0],
options.stride[0], options.stride[0],
options.dilation[0], options.dilation[0],
options.groups,
) )
.unwrap(); .unwrap();
CandleTensor::new(match bias { CandleTensor::new(match bias {

View File

@ -10,7 +10,6 @@ burn-ndarray = { path = "../../burn-ndarray" }
serde = { workspace = true } serde = { workspace = true }
float-cmp = { workspace = true } float-cmp = { workspace = true }
burn-import = { path = "../", features = ["pytorch"] } burn-import = { path = "../", features = ["pytorch"] }
cfg-if = "1.0.0"
[build-dependencies] [build-dependencies]

View File

@ -1,23 +1,17 @@
cfg_if::cfg_if! { mod batch_norm;
if #[cfg(not(target_os = "windows"))] { mod boolean;
// The crate is not supported on Windows because of Candle's pt bug on Windows mod buffer;
// (see https://github.com/huggingface/candle/issues/1454). mod complex_nested;
mod batch_norm; mod config;
mod boolean; mod conv1d;
mod buffer; mod conv2d;
mod complex_nested; mod conv_transpose1d;
mod config; mod conv_transpose2d;
mod conv1d; mod embedding;
mod conv2d; mod group_norm;
mod conv_transpose1d; mod integer;
mod conv_transpose2d; mod key_remap;
mod embedding; mod key_remap_chained;
mod group_norm; mod layer_norm;
mod integer; mod linear;
mod key_remap; mod missing_module_field;
mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;
}
}

View File

@ -13,16 +13,6 @@ use burn_import::pytorch::PyTorchFileRecorder;
type B = NdArray<f32>; type B = NdArray<f32>;
fn main() { fn main() {
if cfg!(target_os = "windows") {
println!(
"{}",
"cargo:warning=The crate is not supported on Windows because of ".to_owned()
+ "Candle's pt bug on Windows "
+ "(see https://github.com/huggingface/candle/issues/1454)."
);
std::process::exit(0);
}
let device = Default::default(); let device = Default::default();
// Load PyTorch weights into a model record. // Load PyTorch weights into a model record.