Enhance PyTorchRecorder to pass top-level key to extract state_dict (#1300)

* Enhance PyTorchRecorder to pass top level key to extract state_dict

This is needed for Whisper weight pt files.

* Fix missing hyphens

* Move top-level-key test under crates

* Add sub-crates as members of workspace

* Update Cargo.lock

* Add accidentally omitted line during merge
This commit is contained in:
Dilshod Tadjibaev 2024-02-29 12:57:27 -06:00 committed by GitHub
parent 4efc683df4
commit 688958ee74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 108 additions and 4 deletions

View File

@ -252,6 +252,24 @@ let model = Net::<Backend>::new_with(record);
with both an encoder and a decoder, it's possible to load only the encoder weights. This is done by
defining the encoder in Burn, allowing the loading of its weights while excluding the decoder's.
### Specifying the top-level key for state_dict
Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
is nested under a top-level key along with other metadata as in a
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
For example, the `state_dict` of the Whisper model is nested under the `model_state_dict` key.
In this case, you can specify the top-level key in `LoadArgs`:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tiny.en.pt".into())
    .with_top_level_key("model_state_dict");
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");
```
## Current known issues
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).

View File

@ -458,7 +458,7 @@ impl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A>
/// A default deserializer that always returns the default value.
struct DefaultDeserializer {
    /// The originator field name (the top-level missing field name)
    originator_field_name: Option<String>,
}

View File

@ -15,3 +15,4 @@ mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;
mod top_level_key;

View File

@ -0,0 +1,24 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Minimal single-conv network used to produce a test checkpoint.

    The weights of this model are saved nested under a top-level key
    (see `main`) to exercise `with_top_level_key` loading.
    """

    def __init__(self):
        # Python 3 zero-argument super() instead of super(Model, self).
        super().__init__()
        # 2 input channels, 2 output channels, 2x2 kernel.
        self.conv1 = nn.Conv2d(2, 2, (2, 2))

    def forward(self, x):
        """Apply the single convolution to `x` and return the result."""
        return self.conv1(x)
def main():
    """Generate `top_level_key.pt`: a checkpoint whose state_dict is
    nested under the top-level key `my_state_dict`."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    # Nest the state_dict under a key, as a "general checkpoint" would.
    checkpoint = {"my_state_dict": model.state_dict()}
    torch.save(checkpoint, "top_level_key.pt")


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,36 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};

/// Minimal network with a single conv layer, used to exercise loading a
/// PyTorch checkpoint whose `state_dict` is nested under a top-level key.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
}
#[cfg(test)]
mod tests {
    use super::*;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

    type Backend = burn_ndarray::NdArray<f32>;

    /// Without a top-level key the recorder expects tensors at the root of
    /// the pickle; this fixture nests them, so loading must panic.
    #[test]
    #[should_panic]
    fn should_fail_if_not_found() {
        let device = Default::default();
        let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
        let _record: NetRecord<Backend> = recorder
            .load("tests/top_level_key/top_level_key.pt".into(), &device)
            .expect("Should decode state successfully");
    }

    /// Pointing the recorder at the nested `my_state_dict` entry succeeds.
    #[test]
    fn should_load() {
        let device = Default::default();
        let args = LoadArgs::new("tests/top_level_key/top_level_key.pt".into())
            .with_top_level_key("my_state_dict");
        let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
        let _record: NetRecord<Backend> = recorder
            .load(args, &device)
            .expect("Should decode state successfully");
    }
}

View File

@ -30,14 +30,19 @@ use serde::{de::DeserializeOwned, Serialize};
///
/// * `path` - A string slice that holds the path of the file to read.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
pub fn from_file<PS, D, B>(path: &Path, key_remap: Vec<(Regex, String)>) -> Result<D, Error>
/// * `top_level_key` - An optional top-level key to load state_dict from a dictionary.
pub fn from_file<PS, D, B>(
path: &Path,
key_remap: Vec<(Regex, String)>,
top_level_key: Option<&str>,
) -> Result<D, Error>
where
D: DeserializeOwned,
PS: PrecisionSettings,
B: Backend,
{
// Read the pickle file and return a vector of Candle tensors
let tensors: HashMap<String, CandleTensor> = pickle::read_all(path)?
let tensors: HashMap<String, CandleTensor> = pickle::read_all_with_key(path, top_level_key)?
.into_iter()
.map(|(key, tensor)| (key, CandleTensor(tensor)))
.collect();

View File

@ -44,7 +44,11 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
args: Self::LoadArgs,
device: &B::Device,
) -> Result<R, RecorderError> {
let item = from_file::<PS, R::Item<Self::Settings>, B>(&args.file, args.key_remap)?;
let item = from_file::<PS, R::Item<Self::Settings>, B>(
&args.file,
args.key_remap,
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
)?;
Ok(R::from_item(item, device))
}
}
@ -84,6 +88,10 @@ pub struct LoadArgs {
/// A list of key remappings.
pub key_remap: Vec<(Regex, String)>,
/// Top-level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top-level key in a dict.
pub top_level_key: Option<String>,
}
impl LoadArgs {
@ -96,6 +104,7 @@ impl LoadArgs {
Self {
file,
key_remap: Vec::new(),
top_level_key: None,
}
}
@ -115,6 +124,17 @@ impl LoadArgs {
self.key_remap.push((regex, replacement.into()));
self
}
/// Set top-level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top-level key in a dict.
///
/// # Arguments
///
/// * `key` - The top-level key to load state_dict from the file. Accepts
///   anything convertible into a `String` (`&str`, `String`, …), which is
///   backward-compatible with the previous `&str`-only signature.
pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
    self.top_level_key = Some(key.into());
    self
}
}
impl From<PathBuf> for LoadArgs {