PyTorchFileRecord print debug option (#1425)

* Add debug print option to PyTorchFileRecorder

* Updated documentation and improved print output

* Improve print wording

* Updated per PR feedback
Dilshod Tadjibaev 2024-03-06 16:11:37 -06:00 committed by GitHub
parent b429cc39c1
commit 545444c02a
5 changed files with 98 additions and 15 deletions


@@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in
 You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes
 by specifying a regular expression (See
 [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and
-[try it online](https://rregex.dev/?version=1.10&method=replace)) to
-match the attribute name and a replacement string in `LoadArgs`:
+[try it online](https://rregex.dev/?version=1.10&method=replace)) to match the attribute name and a
+replacement string in `LoadArgs`:
 
 ```rust
 let device = Default::default();
@@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
 let model = Net::<Backend>::new_with(record);
 ```
 
+### Printing the source model keys and tensor information
+
+If you are unsure about the keys in the source model, you can print them using the following code:
+
+```rust
+let device = Default::default();
+let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
+    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
+    .with_key_remap("conv\\.(.*)", "$1")
+    .with_debug_print(); // Print the keys and remapped keys
+
+let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
+    .load(load_args, &device)
+    .expect("Should decode state successfully");
+
+let model = Net::<Backend>::new_with(record);
+```
+
+Here is an example of the output:
+
+```text
+Debug information of keys and tensor shapes:
+---
+Original Key: conv.conv1.bias
+Remapped Key: conv1.bias
+Shape: [2]
+Dtype: F32
+---
+Original Key: conv.conv1.weight
+Remapped Key: conv1.weight
+Shape: [2, 2, 2, 2]
+Dtype: F32
+---
+Original Key: conv.conv2.weight
+Remapped Key: conv2.weight
+Shape: [2, 2, 2, 2]
+Dtype: F32
+---
+```
+
 ### Loading the model weights to a partial model
 
 `PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model
@@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excluding
 ### Specifying the top-level key for state_dict
 
-Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
+Sometimes the
+[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
 is nested under a top-level key along with other metadata as in a
 [general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
-For example, the `state_dict` of the whisper model is nested under `model_state_dict` key.
-In this case, you can specify the top-level key in `LoadArgs`:
+For example, the `state_dict` of the whisper model is nested under `model_state_dict` key. In this
+case, you can specify the top-level key in `LoadArgs`:
 
 ```rust
 let device = Default::default();
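
Note: the `LoadArgs` snippet above is cut off by the hunk boundary. For reference, a minimal sketch of how the top-level key would be specified; the checkpoint path is a placeholder, `Net` and `Backend` stand in for the model and backend types from the earlier book examples, the builder method is assumed to be named `with_top_level_key` (only its doc comment and body appear in the hunks below), and the import paths are assumed for the `burn-import` crate:

```rust
use burn::record::FullPrecisionSettings;
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

let device = Default::default();

// "whisper.pt" is a placeholder checkpoint whose state_dict is nested
// under the "model_state_dict" key, as in the whisper example above.
let load_args = LoadArgs::new("whisper.pt".into())
    .with_top_level_key("model_state_dict");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::new_with(record);
```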


@@ -182,19 +182,25 @@ impl NestedValue {
 /// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
 /// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
 /// for more information.
 ///
 /// # Returns
 ///
-/// A map of tensors with the remapped keys.
+/// A map of tensors with the remapped keys and
+/// a vector of tuples containing the remapped and original.
 pub fn remap<T>(
     mut tensors: HashMap<String, T>,
     key_remap: Vec<(Regex, String)>,
-) -> HashMap<String, T> {
+) -> (HashMap<String, T>, Vec<(String, String)>) {
     if key_remap.is_empty() {
-        return tensors;
+        let remapped_names = tensors
+            .keys()
+            .cloned()
+            .map(|s| (s.clone(), s)) // Name is the same as the remapped name
+            .collect();
+        return (tensors, remapped_names);
     }
 
     let mut remapped = HashMap::new();
+    let mut remapped_names = Vec::new();
 
     for (name, tensor) in tensors.drain() {
         let mut new_name = name.clone();
@@ -205,10 +211,12 @@ pub fn remap<T>(
                 .to_string();
             }
         }
 
+        remapped_names.push((new_name.clone(), name));
         remapped.insert(new_name, tensor);
     }
 
-    remapped
+    (remapped, remapped_names)
 }
 
 /// Helper function to insert a value into a nested map/vector of tensors.
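
Note: the new second return value pairs each remapped key with the key it came from, in `(remapped, original)` order. Illustrative only, using the `"conv\\.(.*)"` to `"$1"` remapping from the book example above (entries shown sorted; the order coming out of the `HashMap` is not guaranteed until the debug path sorts it):

```rust
// Illustrative values: (remapped key, original key) pairs for the example model.
let remapped_names: Vec<(String, String)> = vec![
    ("conv1.bias".into(), "conv.conv1.bias".into()),
    ("conv1.weight".into(), "conv.conv1.weight".into()),
    ("conv2.weight".into(), "conv.conv2.weight".into()),
];
```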


@@ -41,7 +41,8 @@ mod tests {
     fn key_remap() {
         let device = Default::default();
         let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
-            .with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
+            .with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
+            .with_debug_print();
 
         let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
             .load(load_args, &device)


@@ -35,6 +35,7 @@ pub fn from_file<PS, D, B>(
     path: &Path,
     key_remap: Vec<(Regex, String)>,
     top_level_key: Option<&str>,
+    debug: bool,
 ) -> Result<D, Error>
 where
     D: DeserializeOwned,
@@ -48,7 +49,28 @@ where
         .collect();
 
     // Remap the keys (replace the keys in the map with the new keys)
-    let tensors = remap(tensors, key_remap);
+    let (tensors, remapped_keys) = remap(tensors, key_remap);
+
+    // Print the remapped keys if debug is enabled
+    if debug {
+        let mut remapped_keys = remapped_keys;
+        remapped_keys.sort();
+        println!("Debug information of keys and tensor shapes:\n---");
+        for (new_key, old_key) in remapped_keys {
+            if old_key != new_key {
+                println!("Original Key: {old_key}");
+                println!("Remapped Key: {new_key}");
+            } else {
+                println!("Key: {}", new_key);
+            }
+            let shape = tensors[&new_key].shape();
+            let dtype = tensors[&new_key].dtype();
+            println!("Shape: {shape:?}");
+            println!("Dtype: {dtype:?}");
+            println!("---");
+        }
+    }
 
     // Convert the vector of Candle tensors to a nested value data structure
     let nested_value = unflatten::<PS, _>(tensors)?;


@@ -48,6 +48,7 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
             &args.file,
             args.key_remap,
             args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
+            args.debug,
         )?;
         Ok(R::from_item(item, device))
     }
@@ -92,10 +93,13 @@ pub struct LoadArgs {
     /// Top-level key to load state_dict from the file.
     /// Sometimes the state_dict is nested under a top-level key in a dict.
     pub top_level_key: Option<String>,
+
+    /// Whether to print debug information.
+    pub debug: bool,
 }
 
 impl LoadArgs {
-    /// Create a new `LoadArgs` instance.
+    /// Creates a new `LoadArgs` instance.
     ///
     /// # Arguments
     ///
@@ -105,10 +109,11 @@ impl LoadArgs {
             file,
             key_remap: Vec::new(),
             top_level_key: None,
+            debug: false,
         }
     }
 
-    /// Set key remapping.
+    /// Sets key remapping.
     ///
     /// # Arguments
     ///
@@ -125,7 +130,7 @@ impl LoadArgs {
         self
     }
 
-    /// Set top-level key to load state_dict from the file.
+    /// Sets the top-level key to load state_dict from the file.
     /// Sometimes the state_dict is nested under a top-level key in a dict.
     ///
    /// # Arguments
@@ -135,6 +140,12 @@ impl LoadArgs {
         self.top_level_key = Some(key.into());
         self
     }
+
+    /// Sets printing debug information on.
+    pub fn with_debug_print(mut self) -> Self {
+        self.debug = true;
+        self
+    }
 }
 
 impl From<PathBuf> for LoadArgs {
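
Note: taken together, the new `debug` field slots into the existing `LoadArgs` builder chain. A minimal usage sketch combining the options touched by this commit; the checkpoint path is a placeholder, the top-level key follows the whisper example from the book section above, and only `with_debug_print` is new here:

```rust
let load_args = LoadArgs::new("model.pt".into())   // placeholder checkpoint path
    .with_key_remap("conv\\.(.*)", "$1")            // rename keys with a regex
    .with_top_level_key("model_state_dict")         // unwrap a nested state_dict (if any)
    .with_debug_print();                            // new: print keys, shapes and dtypes while loading
```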