mirror of https://github.com/tracel-ai/burn.git
Upgrade to candle 0.4.1 (#1382)
* Fix python main entrypoint in book example * Remove candle windows safeguards (#1178) * Bump candle-core from 0.3.3 to 0.4.1 * Remove windows current known issue
This commit is contained in:
parent
40bf3927f0
commit
4efc683df4
|
@ -691,9 +691,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "candle-core"
|
name = "candle-core"
|
||||||
version = "0.3.3"
|
version = "0.4.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6db8659ea87ee8197d2fc627348916cce0561330ee7ae3874e771691d3cecb2f"
|
checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"accelerate-src",
|
"accelerate-src",
|
||||||
"byteorder",
|
"byteorder",
|
||||||
|
@ -718,18 +718,18 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.3.3"
|
version = "0.4.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d80cdd5f1cc60d30ba61353cdba5accd0fbc4d4ef2fe707fcb5179a9821adbea"
|
checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bindgen_cuda",
|
"bindgen_cuda",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.3.3"
|
version = "0.4.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "52567e7a314ae0c59db5fbd4874ce461d99fa22adb22ddf7cf296b4d97035b40"
|
checksum = "b20d6c0d49121e2709ed9faa958ba915ea59526036bcf27558817d1452a4ff09"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"metal",
|
"metal",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
@ -3521,7 +3521,6 @@ dependencies = [
|
||||||
"burn",
|
"burn",
|
||||||
"burn-import",
|
"burn-import",
|
||||||
"burn-ndarray",
|
"burn-ndarray",
|
||||||
"cfg-if",
|
|
||||||
"float-cmp",
|
"float-cmp",
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
|
@ -25,7 +25,7 @@ license = "MIT OR Apache-2.0"
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
async-trait = "0.1.74"
|
async-trait = "0.1.74"
|
||||||
bytemuck = "1.14"
|
bytemuck = "1.14"
|
||||||
candle-core = { version = "0.3.3" }
|
candle-core = { version = "0.4.1" }
|
||||||
clap = { version = "4.5.1", features = ["derive"] }
|
clap = { version = "4.5.1", features = ["derive"] }
|
||||||
console_error_panic_hook = "0.1.7"
|
console_error_panic_hook = "0.1.7"
|
||||||
csv = "1.3.0"
|
csv = "1.3.0"
|
||||||
|
|
|
@ -30,7 +30,7 @@ class Net(nn.Module):
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
torch.manual_seed(42) # To make it reproducible
|
torch.manual_seed(42) # To make it reproducible
|
||||||
model = Net().to(torch.device("cpu"))
|
model = Net().to(torch.device("cpu"))
|
||||||
model_weights = model.state_dict()
|
model_weights = model.state_dict()
|
||||||
|
@ -254,5 +254,4 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
|
||||||
|
|
||||||
## Current known issues
|
## Current known issues
|
||||||
|
|
||||||
1. [Candle's pickle library does not currently function on Windows due to a Candle bug](https://github.com/tracel-ai/burn/issues/1178).
|
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
|
||||||
2. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
|
|
||||||
|
|
|
@ -83,10 +83,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
||||||
bias: Option<FloatTensor<Self, 1>>,
|
bias: Option<FloatTensor<Self, 1>>,
|
||||||
options: ConvTransposeOptions<1>,
|
options: ConvTransposeOptions<1>,
|
||||||
) -> FloatTensor<Self, 3> {
|
) -> FloatTensor<Self, 3> {
|
||||||
assert!(
|
|
||||||
options.groups == 1,
|
|
||||||
"Candle does not support groups in transposed convolutions"
|
|
||||||
);
|
|
||||||
let conv_transpose = x
|
let conv_transpose = x
|
||||||
.tensor
|
.tensor
|
||||||
.conv_transpose1d(
|
.conv_transpose1d(
|
||||||
|
@ -95,6 +91,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
||||||
options.padding_out[0],
|
options.padding_out[0],
|
||||||
options.stride[0],
|
options.stride[0],
|
||||||
options.dilation[0],
|
options.dilation[0],
|
||||||
|
options.groups,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
CandleTensor::new(match bias {
|
CandleTensor::new(match bias {
|
||||||
|
|
|
@ -10,7 +10,6 @@ burn-ndarray = { path = "../../burn-ndarray" }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
float-cmp = { workspace = true }
|
float-cmp = { workspace = true }
|
||||||
burn-import = { path = "../", features = ["pytorch"] }
|
burn-import = { path = "../", features = ["pytorch"] }
|
||||||
cfg-if = "1.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|
|
@ -1,23 +1,17 @@
|
||||||
cfg_if::cfg_if! {
|
mod batch_norm;
|
||||||
if #[cfg(not(target_os = "windows"))] {
|
mod boolean;
|
||||||
// The crate is not supported on Windows because of Candle's pt bug on Windows
|
mod buffer;
|
||||||
// (see https://github.com/huggingface/candle/issues/1454).
|
mod complex_nested;
|
||||||
mod batch_norm;
|
mod config;
|
||||||
mod boolean;
|
mod conv1d;
|
||||||
mod buffer;
|
mod conv2d;
|
||||||
mod complex_nested;
|
mod conv_transpose1d;
|
||||||
mod config;
|
mod conv_transpose2d;
|
||||||
mod conv1d;
|
mod embedding;
|
||||||
mod conv2d;
|
mod group_norm;
|
||||||
mod conv_transpose1d;
|
mod integer;
|
||||||
mod conv_transpose2d;
|
mod key_remap;
|
||||||
mod embedding;
|
mod key_remap_chained;
|
||||||
mod group_norm;
|
mod layer_norm;
|
||||||
mod integer;
|
mod linear;
|
||||||
mod key_remap;
|
mod missing_module_field;
|
||||||
mod key_remap_chained;
|
|
||||||
mod layer_norm;
|
|
||||||
mod linear;
|
|
||||||
mod missing_module_field;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -13,16 +13,6 @@ use burn_import::pytorch::PyTorchFileRecorder;
|
||||||
type B = NdArray<f32>;
|
type B = NdArray<f32>;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
if cfg!(target_os = "windows") {
|
|
||||||
println!(
|
|
||||||
"{}",
|
|
||||||
"cargo:warning=The crate is not supported on Windows because of ".to_owned()
|
|
||||||
+ "Candle's pt bug on Windows "
|
|
||||||
+ "(see https://github.com/huggingface/candle/issues/1454)."
|
|
||||||
);
|
|
||||||
std::process::exit(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
|
|
||||||
// Load PyTorch weights into a model record.
|
// Load PyTorch weights into a model record.
|
||||||
|
|
Loading…
Reference in New Issue