mirror of https://github.com/tracel-ai/burn.git
Upgrade to candle 0.4.1 (#1382)
* Fix python main entrypoint in book example * Remove candle windows safeguards (#1178) * Bump candle-core from 0.3.3 to 0.4.1 * Remove windows current known issue
This commit is contained in:
parent
40bf3927f0
commit
4efc683df4
|
@ -691,9 +691,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.3.3"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6db8659ea87ee8197d2fc627348916cce0561330ee7ae3874e771691d3cecb2f"
|
||||
checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"byteorder",
|
||||
|
@ -718,18 +718,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.3.3"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d80cdd5f1cc60d30ba61353cdba5accd0fbc4d4ef2fe707fcb5179a9821adbea"
|
||||
checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
|
||||
dependencies = [
|
||||
"bindgen_cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.3.3"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52567e7a314ae0c59db5fbd4874ce461d99fa22adb22ddf7cf296b4d97035b40"
|
||||
checksum = "b20d6c0d49121e2709ed9faa958ba915ea59526036bcf27558817d1452a4ff09"
|
||||
dependencies = [
|
||||
"metal",
|
||||
"once_cell",
|
||||
|
@ -3521,7 +3521,6 @@ dependencies = [
|
|||
"burn",
|
||||
"burn-import",
|
||||
"burn-ndarray",
|
||||
"cfg-if",
|
||||
"float-cmp",
|
||||
"serde",
|
||||
]
|
||||
|
|
|
@ -25,7 +25,7 @@ license = "MIT OR Apache-2.0"
|
|||
[workspace.dependencies]
|
||||
async-trait = "0.1.74"
|
||||
bytemuck = "1.14"
|
||||
candle-core = { version = "0.3.3" }
|
||||
candle-core = { version = "0.4.1" }
|
||||
clap = { version = "4.5.1", features = ["derive"] }
|
||||
console_error_panic_hook = "0.1.7"
|
||||
csv = "1.3.0"
|
||||
|
|
|
@ -30,7 +30,7 @@ class Net(nn.Module):
|
|||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def main():
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(42) # To make it reproducible
|
||||
model = Net().to(torch.device("cpu"))
|
||||
model_weights = model.state_dict()
|
||||
|
@ -254,5 +254,4 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
|
|||
|
||||
## Current known issues
|
||||
|
||||
1. [Candle's pickle library does not currently function on Windows due to a Candle bug](https://github.com/tracel-ai/burn/issues/1178).
|
||||
2. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
|
||||
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
|
||||
|
|
|
@ -83,10 +83,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
|||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self, 3> {
|
||||
assert!(
|
||||
options.groups == 1,
|
||||
"Candle does not support groups in transposed convolutions"
|
||||
);
|
||||
let conv_transpose = x
|
||||
.tensor
|
||||
.conv_transpose1d(
|
||||
|
@ -95,6 +91,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
|||
options.padding_out[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
options.groups,
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
|
|
|
@ -10,7 +10,6 @@ burn-ndarray = { path = "../../burn-ndarray" }
|
|||
serde = { workspace = true }
|
||||
float-cmp = { workspace = true }
|
||||
burn-import = { path = "../", features = ["pytorch"] }
|
||||
cfg-if = "1.0.0"
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
|
|
@ -1,23 +1,17 @@
|
|||
cfg_if::cfg_if! {
|
||||
if #[cfg(not(target_os = "windows"))] {
|
||||
// The crate is not supported on Windows because of Candle's pt bug on Windows
|
||||
// (see https://github.com/huggingface/candle/issues/1454).
|
||||
mod batch_norm;
|
||||
mod boolean;
|
||||
mod buffer;
|
||||
mod complex_nested;
|
||||
mod config;
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod embedding;
|
||||
mod group_norm;
|
||||
mod integer;
|
||||
mod key_remap;
|
||||
mod key_remap_chained;
|
||||
mod layer_norm;
|
||||
mod linear;
|
||||
mod missing_module_field;
|
||||
}
|
||||
}
|
||||
mod batch_norm;
|
||||
mod boolean;
|
||||
mod buffer;
|
||||
mod complex_nested;
|
||||
mod config;
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod embedding;
|
||||
mod group_norm;
|
||||
mod integer;
|
||||
mod key_remap;
|
||||
mod key_remap_chained;
|
||||
mod layer_norm;
|
||||
mod linear;
|
||||
mod missing_module_field;
|
||||
|
|
|
@ -13,16 +13,6 @@ use burn_import::pytorch::PyTorchFileRecorder;
|
|||
type B = NdArray<f32>;
|
||||
|
||||
fn main() {
|
||||
if cfg!(target_os = "windows") {
|
||||
println!(
|
||||
"{}",
|
||||
"cargo:warning=The crate is not supported on Windows because of ".to_owned()
|
||||
+ "Candle's pt bug on Windows "
|
||||
+ "(see https://github.com/huggingface/candle/issues/1454)."
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
let device = Default::default();
|
||||
|
||||
// Load PyTorch weights into a model record.
|
||||
|
|
Loading…
Reference in New Issue