mirror of https://github.com/tracel-ai/burn.git
Fix pytorch recorder adapt_linear when using autodiff backend (#1576)
* Fix pytorch recorder adapt_linear when using autodiff backend * Fix comment typo
This commit is contained in:
parent
65222761fd
commit
ce898ff899
|
@ -7,6 +7,7 @@ license.workspace = true
|
|||
[dev-dependencies]
|
||||
burn = { path = "../../burn" }
|
||||
burn-ndarray = { path = "../../burn-ndarray" }
|
||||
burn-autodiff = { path = "../../burn-autodiff" }
|
||||
serde = { workspace = true }
|
||||
float-cmp = { workspace = true }
|
||||
burn-import = { path = "../", features = ["pytorch"] }
|
||||
|
|
|
@ -12,6 +12,7 @@ use burn::{
|
|||
Tensor,
|
||||
},
|
||||
};
|
||||
use burn_autodiff::Autodiff;
|
||||
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
@ -164,6 +165,17 @@ fn full_record() {
|
|||
model_test(record, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_record_autodiff() {
|
||||
let device = Default::default();
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("tests/complex_nested/complex_nested.pt".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
let device = Default::default();
|
||||
let _model = Net::<Autodiff<TestBackend>>::init(&device).load_record(record);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn half_record() {
|
||||
let device = Default::default();
|
||||
|
|
|
@ -36,6 +36,8 @@ impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS,
|
|||
.try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
|
||||
.expect("Failed to deserialize weight");
|
||||
|
||||
// Do not capture transpose op when using autodiff backend
|
||||
let weight = weight.set_require_grad(false);
|
||||
// Transpose the weight tensor.
|
||||
let weight_transposed = Param::from_tensor(weight.val().transpose());
|
||||
|
||||
|
|
Loading…
Reference in New Issue