mirror of https://github.com/tracel-ai/burn.git
Enhance PyTorchRecorder to pass top-level key to extract state_dict (#1300)
* Enhance PyTorchRecorder to pass top level key to extract state_dict This is needed for Whisper weight pt files. * Fix missing hyphens * Move top-level-key test under crates * Add sub-crates as members of workspace * Update Cargo.lock * Add accidentally omitted line during merge
This commit is contained in:
parent
4efc683df4
commit
688958ee74
|
@ -252,6 +252,24 @@ let model = Net::<Backend>::new_with(record);
|
|||
with both an encoder and a decoder, it's possible to load only the encoder weights. This is done by
|
||||
defining the encoder in Burn, allowing the loading of its weights while excluding the decoder's.
|
||||
|
||||
### Specifying the top-level key for state_dict
|
||||
|
||||
Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
|
||||
is nested under a top-level key along with other metadata as in a
|
||||
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
|
||||
For example, the `state_dict` of the Whisper model is nested under the `model_state_dict` key.
|
||||
In this case, you can specify the top-level key in `LoadArgs`:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tiny.en.pt".into())
|
||||
.with_top_level_key("model_state_dict");
|
||||
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
```
|
||||
|
||||
## Current known issues
|
||||
|
||||
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
|
||||
|
|
|
@ -458,7 +458,7 @@ impl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A>
|
|||
|
||||
/// A default deserializer that always returns the default value.
|
||||
struct DefaultDeserializer {
|
||||
/// The originator field name (the top level missing field name)
|
||||
/// The originator field name (the top-level missing field name)
|
||||
originator_field_name: Option<String>,
|
||||
}
|
||||
|
||||
|
|
|
@ -15,3 +15,4 @@ mod key_remap_chained;
|
|||
mod layer_norm;
|
||||
mod linear;
|
||||
mod missing_module_field;
|
||||
mod top_level_key;
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Minimal network holding a single 2-D convolution.

    It exists only so that its `state_dict` can be saved into a test
    checkpoint file.
    """

    def __init__(self):
        super().__init__()
        # 2 input channels, 2 output channels, 2x2 kernel.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(2, 2))

    def forward(self, x):
        """Run `x` through the convolution and return the result."""
        return self.conv1(x)
|
||||
|
||||
|
||||
def main():
    """Write `top_level_key.pt`, a checkpoint whose state_dict is nested.

    The model's `state_dict` is stored under the `my_state_dict` key of a
    plain dictionary instead of being saved directly.
    """
    torch.set_printoptions(precision=8)
    # Fixed seed so the generated weights are reproducible.
    torch.manual_seed(1)

    cpu_model = Model().to(torch.device("cpu"))
    checkpoint = {"my_state_dict": cpu_model.state_dict()}
    torch.save(checkpoint, "top_level_key.pt")


if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,36 @@
|
|||
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

    use super::*;

    type Backend = burn_ndarray::NdArray<f32>;

    /// Path of the checkpoint file that both tests load.
    const CHECKPOINT: &str = "tests/top_level_key/top_level_key.pt";

    /// Loading the checkpoint without a top-level key is expected to panic.
    #[test]
    #[should_panic]
    fn should_fail_if_not_found() {
        let device = Default::default();

        let _record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(CHECKPOINT.into(), &device)
            .expect("Should decode state successfully");
    }

    /// Loading succeeds when `my_state_dict` is given as the top-level key.
    #[test]
    fn should_load() {
        let device = Default::default();
        let args = LoadArgs::new(CHECKPOINT.into()).with_top_level_key("my_state_dict");

        let _record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(args, &device)
            .expect("Should decode state successfully");
    }
}
|
Binary file not shown.
|
@ -30,14 +30,19 @@ use serde::{de::DeserializeOwned, Serialize};
|
|||
///
|
||||
/// * `path` - The path of the file to read.
|
||||
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
|
||||
pub fn from_file<PS, D, B>(path: &Path, key_remap: Vec<(Regex, String)>) -> Result<D, Error>
|
||||
/// * `top_level_key` - An optional top-level key to load state_dict from a dictionary.
|
||||
pub fn from_file<PS, D, B>(
|
||||
path: &Path,
|
||||
key_remap: Vec<(Regex, String)>,
|
||||
top_level_key: Option<&str>,
|
||||
) -> Result<D, Error>
|
||||
where
|
||||
D: DeserializeOwned,
|
||||
PS: PrecisionSettings,
|
||||
B: Backend,
|
||||
{
|
||||
// Read the pickle file and return a vector of Candle tensors
|
||||
let tensors: HashMap<String, CandleTensor> = pickle::read_all(path)?
|
||||
let tensors: HashMap<String, CandleTensor> = pickle::read_all_with_key(path, top_level_key)?
|
||||
.into_iter()
|
||||
.map(|(key, tensor)| (key, CandleTensor(tensor)))
|
||||
.collect();
|
||||
|
|
|
@ -44,7 +44,11 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
|
|||
args: Self::LoadArgs,
|
||||
device: &B::Device,
|
||||
) -> Result<R, RecorderError> {
|
||||
let item = from_file::<PS, R::Item<Self::Settings>, B>(&args.file, args.key_remap)?;
|
||||
let item = from_file::<PS, R::Item<Self::Settings>, B>(
|
||||
&args.file,
|
||||
args.key_remap,
|
||||
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
|
||||
)?;
|
||||
Ok(R::from_item(item, device))
|
||||
}
|
||||
}
|
||||
|
@ -84,6 +88,10 @@ pub struct LoadArgs {
|
|||
|
||||
/// A list of key remappings.
|
||||
pub key_remap: Vec<(Regex, String)>,
|
||||
|
||||
/// Top-level key to load state_dict from the file.
|
||||
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
||||
pub top_level_key: Option<String>,
|
||||
}
|
||||
|
||||
impl LoadArgs {
|
||||
|
@ -96,6 +104,7 @@ impl LoadArgs {
|
|||
Self {
|
||||
file,
|
||||
key_remap: Vec::new(),
|
||||
top_level_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,6 +124,17 @@ impl LoadArgs {
|
|||
self.key_remap.push((regex, replacement.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Set top-level key to load state_dict from the file.
|
||||
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - The top-level key to load state_dict from the file.
|
||||
pub fn with_top_level_key(mut self, key: &str) -> Self {
|
||||
self.top_level_key = Some(key.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for LoadArgs {
|
||||
|
|
Loading…
Reference in New Issue