Fix pytorch recorder adapt_linear when using autodiff backend (#1576)

* Fix pytorch recorder adapt_linear when using autodiff backend

* Fix comment typo
This commit is contained in:
Guillaume Lagrange 2024-04-04 12:29:24 -04:00 committed by GitHub
parent 65222761fd
commit ce898ff899
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ license.workspace = true
[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["pytorch"] }

View File

@@ -12,6 +12,7 @@ use burn::{
Tensor,
},
};
use burn_autodiff::Autodiff;
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
#[derive(Module, Debug)]
@@ -164,6 +165,17 @@ fn full_record() {
model_test(record, 8);
}
#[test]
fn full_record_autodiff() {
    // Regression test for loading a PyTorch record and initializing the model
    // on the autodiff backend (the recorder's `adapt_linear` must not capture
    // the weight-transpose op in the autodiff graph).
    //
    // NOTE(review): the original declared `device` twice with identical
    // `Default::default()` values; `Autodiff<B>::Device` is `B::Device`, so a
    // single binding serves both the recorder load and the model init.
    let device = Default::default();
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("tests/complex_nested/complex_nested.pt".into(), &device)
        .expect("Should decode state successfully");
    let _model = Net::<Autodiff<TestBackend>>::init(&device).load_record(record);
}
#[test]
fn half_record() {
let device = Default::default();

View File

@@ -36,6 +36,8 @@ impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS,
.try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
.expect("Failed to deserialize weight");
// Do not capture transpose op when using autodiff backend
let weight = weight.set_require_grad(false);
// Transpose the weight tensor.
let weight_transposed = Param::from_tensor(weight.val().transpose());