Upgrade to candle 0.4.1 (#1382)

* Fix python main entrypoint in book example

* Remove candle windows safeguards (#1178)

* Bump candle-core from 0.3.3 to 0.4.1

* Remove windows current known issue
This commit is contained in:
Guillaume Lagrange 2024-02-29 12:29:11 -05:00 committed by GitHub
parent 40bf3927f0
commit 4efc683df4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 27 additions and 49 deletions

13
Cargo.lock generated
View File

@ -691,9 +691,9 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.3.3"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db8659ea87ee8197d2fc627348916cce0561330ee7ae3874e771691d3cecb2f"
checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
dependencies = [
"accelerate-src",
"byteorder",
@ -718,18 +718,18 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.3.3"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d80cdd5f1cc60d30ba61353cdba5accd0fbc4d4ef2fe707fcb5179a9821adbea"
checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-metal-kernels"
version = "0.3.3"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52567e7a314ae0c59db5fbd4874ce461d99fa22adb22ddf7cf296b4d97035b40"
checksum = "b20d6c0d49121e2709ed9faa958ba915ea59526036bcf27558817d1452a4ff09"
dependencies = [
"metal",
"once_cell",
@ -3521,7 +3521,6 @@ dependencies = [
"burn",
"burn-import",
"burn-ndarray",
"cfg-if",
"float-cmp",
"serde",
]

View File

@ -25,7 +25,7 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
async-trait = "0.1.74"
bytemuck = "1.14"
candle-core = { version = "0.3.3" }
candle-core = { version = "0.4.1" }
clap = { version = "4.5.1", features = ["derive"] }
console_error_panic_hook = "0.1.7"
csv = "1.3.0"

View File

@ -30,7 +30,7 @@ class Net(nn.Module):
x = self.conv2(x)
return x
def main():
if __name__ == "__main__":
torch.manual_seed(42) # To make it reproducible
model = Net().to(torch.device("cpu"))
model_weights = model.state_dict()
@ -254,5 +254,4 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
## Current known issues
1. [Candle's pickle library does not currently function on Windows due to a Candle bug](https://github.com/tracel-ai/burn/issues/1178).
2. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).

View File

@ -83,10 +83,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> {
assert!(
options.groups == 1,
"Candle does not support groups in transposed convolutions"
);
let conv_transpose = x
.tensor
.conv_transpose1d(
@ -95,6 +91,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
options.padding_out[0],
options.stride[0],
options.dilation[0],
options.groups,
)
.unwrap();
CandleTensor::new(match bias {

View File

@ -10,7 +10,6 @@ burn-ndarray = { path = "../../burn-ndarray" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["pytorch"] }
cfg-if = "1.0.0"
[build-dependencies]

View File

@ -1,23 +1,17 @@
cfg_if::cfg_if! {
if #[cfg(not(target_os = "windows"))] {
// The crate is not supported on Windows because of Candle's pt bug on Windows
// (see https://github.com/huggingface/candle/issues/1454).
mod batch_norm;
mod boolean;
mod buffer;
mod complex_nested;
mod config;
mod conv1d;
mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod group_norm;
mod integer;
mod key_remap;
mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;
}
}
mod batch_norm;
mod boolean;
mod buffer;
mod complex_nested;
mod config;
mod conv1d;
mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod group_norm;
mod integer;
mod key_remap;
mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;

View File

@ -13,16 +13,6 @@ use burn_import::pytorch::PyTorchFileRecorder;
type B = NdArray<f32>;
fn main() {
if cfg!(target_os = "windows") {
println!(
"{}",
"cargo:warning=The crate is not supported on Windows because of ".to_owned()
+ "Candle's pt bug on Windows "
+ "(see https://github.com/huggingface/candle/issues/1454)."
);
std::process::exit(0);
}
let device = Default::default();
// Load PyTorch weights into a model record.