diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md
index a2356a16e..12d81c8f2 100644
--- a/burn-book/src/import/pytorch-model.md
+++ b/burn-book/src/import/pytorch-model.md
@@ -252,6 +252,24 @@ let model = Net::<Backend>::new_with(record);
 with both an encoder and a decoder, it's possible to load only the encoder weights. This is done by
 defining the encoder in Burn, allowing the loading of its weights while excluding the decoder's.
 
+### Specifying the top-level key for state_dict
+
+Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
+is nested under a top-level key along with other metadata as in a
+[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
+For example, the `state_dict` of the whisper model is nested under the `model_state_dict` key.
+In this case, you can specify the top-level key in `LoadArgs`:
+
+```rust
+let device = Default::default();
+let load_args = LoadArgs::new("tiny.en.pt".into())
+    .with_top_level_key("model_state_dict");
+
+let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
+    .load(load_args, &device)
+    .expect("Should decode state successfully");
+```
+
 ## Current known issues
 
 1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs
index 1effa6661..5407dcded 100644
--- a/crates/burn-core/src/record/serde/de.rs
+++ b/crates/burn-core/src/record/serde/de.rs
@@ -458,7 +458,7 @@ impl IntoDeserializer<'_, Error> for NestedValueWrapper
 
 /// A default deserializer that always returns the default value.
 struct DefaultDeserializer {
-    /// The originator field name (the top level missing field name)
+    /// The originator field name (the top-level missing field name)
     originator_field_name: Option<String>,
 }
 
diff --git a/crates/burn-import/pytorch-tests/tests/mod.rs b/crates/burn-import/pytorch-tests/tests/mod.rs
index 2e990977b..267930642 100644
--- a/crates/burn-import/pytorch-tests/tests/mod.rs
+++ b/crates/burn-import/pytorch-tests/tests/mod.rs
@@ -15,3 +15,4 @@ mod key_remap_chained;
 mod layer_norm;
 mod linear;
 mod missing_module_field;
+mod top_level_key;
diff --git a/crates/burn-import/pytorch-tests/tests/top_level_key/export_weights.py b/crates/burn-import/pytorch-tests/tests/top_level_key/export_weights.py
new file mode 100755
index 000000000..40456292d
--- /dev/null
+++ b/crates/burn-import/pytorch-tests/tests/top_level_key/export_weights.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.conv1 = nn.Conv2d(2, 2, (2,2))
+
+    def forward(self, x):
+        x = self.conv1(x)
+        return x
+
+
+def main():
+    torch.set_printoptions(precision=8)
+    torch.manual_seed(1)
+    model = Model().to(torch.device("cpu"))
+    torch.save({"my_state_dict": model.state_dict()}, "top_level_key.pt")
+
+if __name__ == '__main__':
+    main()
diff --git a/crates/burn-import/pytorch-tests/tests/top_level_key/mod.rs b/crates/burn-import/pytorch-tests/tests/top_level_key/mod.rs
new file mode 100644
index 000000000..c05278edd
--- /dev/null
+++ b/crates/burn-import/pytorch-tests/tests/top_level_key/mod.rs
@@ -0,0 +1,36 @@
+use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
+
+#[derive(Module, Debug)]
+pub struct Net<B: Backend> {
+    conv1: Conv2d<B>,
+}
+
+#[cfg(test)]
+mod tests {
+    type Backend = burn_ndarray::NdArray<f32>;
+
+    use burn::record::{FullPrecisionSettings, Recorder};
+    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
+
+    use super::*;
+
+    #[test]
+    #[should_panic]
+    fn should_fail_if_not_found() {
+        let device = Default::default();
+        let _record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
+            .load("tests/top_level_key/top_level_key.pt".into(), &device)
+            .expect("Should decode state successfully");
+    }
+
+    #[test]
+    fn should_load() {
+        let device = Default::default();
+        let load_args = LoadArgs::new("tests/top_level_key/top_level_key.pt".into())
+            .with_top_level_key("my_state_dict");
+
+        let _record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
+            .load(load_args, &device)
+            .expect("Should decode state successfully");
+    }
+}
diff --git a/crates/burn-import/pytorch-tests/tests/top_level_key/top_level_key.pt b/crates/burn-import/pytorch-tests/tests/top_level_key/top_level_key.pt
new file mode 100644
index 000000000..2658e0779
Binary files /dev/null and b/crates/burn-import/pytorch-tests/tests/top_level_key/top_level_key.pt differ
diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs
index 9acb66573..61e62e8a2 100644
--- a/crates/burn-import/src/pytorch/reader.rs
+++ b/crates/burn-import/src/pytorch/reader.rs
@@ -30,14 +30,19 @@ use serde::{de::DeserializeOwned, Serialize};
 ///
 /// * `path` - A string slice that holds the path of the file to read.
 /// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
-pub fn from_file<PS, D, B>(path: &Path, key_remap: Vec<(Regex, String)>) -> Result<D, Error>
+/// * `top_level_key` - An optional top-level key to load state_dict from a dictionary.
+pub fn from_file<PS, D, B>(
+    path: &Path,
+    key_remap: Vec<(Regex, String)>,
+    top_level_key: Option<&str>,
+) -> Result<D, Error>
 where
     D: DeserializeOwned,
     PS: PrecisionSettings,
     B: Backend,
 {
     // Read the pickle file and return a vector of Candle tensors
-    let tensors: HashMap<String, CandleTensor> = pickle::read_all(path)?
+    let tensors: HashMap<String, CandleTensor> = pickle::read_all_with_key(path, top_level_key)?
     .into_iter()
     .map(|(key, tensor)| (key, CandleTensor(tensor)))
     .collect();
diff --git a/crates/burn-import/src/pytorch/recorder.rs b/crates/burn-import/src/pytorch/recorder.rs
index ea4429803..e1888c849 100644
--- a/crates/burn-import/src/pytorch/recorder.rs
+++ b/crates/burn-import/src/pytorch/recorder.rs
@@ -44,7 +44,11 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
         args: Self::LoadArgs,
         device: &B::Device,
     ) -> Result<R, RecorderError> {
-        let item = from_file::<PS, R::Item<Self::Settings>, B>(&args.file, args.key_remap)?;
+        let item = from_file::<PS, R::Item<Self::Settings>, B>(
+            &args.file,
+            args.key_remap,
+            args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
+        )?;
         Ok(R::from_item(item, device))
     }
 }
@@ -84,6 +88,10 @@ pub struct LoadArgs {
 
     /// A list of key remappings.
     pub key_remap: Vec<(Regex, String)>,
+
+    /// Top-level key to load state_dict from the file.
+    /// Sometimes the state_dict is nested under a top-level key in a dict.
+    pub top_level_key: Option<String>,
 }
 
 impl LoadArgs {
@@ -96,6 +104,7 @@ impl LoadArgs {
         Self {
             file,
             key_remap: Vec::new(),
+            top_level_key: None,
         }
     }
 
@@ -115,6 +124,17 @@ impl LoadArgs {
         self.key_remap.push((regex, replacement.into()));
         self
     }
+
+    /// Set top-level key to load state_dict from the file.
+    /// Sometimes the state_dict is nested under a top-level key in a dict.
+    ///
+    /// # Arguments
+    ///
+    /// * `key` - The top-level key to load state_dict from the file.
+    pub fn with_top_level_key(mut self, key: &str) -> Self {
+        self.top_level_key = Some(key.into());
+        self
+    }
 }
 
 impl From<PathBuf> for LoadArgs {