PyTorchFileRecord print debug option (#1425)

* Add debug print option to PyTorchFileRecorder

* Updated documentation and improved print output

* Improve print wording

* Updated per PR feedback
This commit is contained in:
Dilshod Tadjibaev 2024-03-06 16:11:37 -06:00 committed by GitHub
parent b429cc39c1
commit 545444c02a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 98 additions and 15 deletions

View File

@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in
You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes
by specifying a regular expression (See
[regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and
[try it online](https://rregex.dev/?version=1.10&method=replace)) to
match the attribute name and a replacement string in `LoadArgs`:
[try it online](https://rregex.dev/?version=1.10&method=replace)) to match the attribute name and a
replacement string in `LoadArgs`:
```rust
let device = Default::default();
@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
let model = Net::<Backend>::new_with(record);
```
### Printing the source model keys and tensor information
If you are unsure about the keys in the source model, you can print them using the following code:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
.with_key_remap("conv\\.(.*)", "$1")
.with_debug_print(); // Print the keys and remapped keys
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
```
Here is an example of the output:
```text
Debug information of keys and tensor shapes:
---
Original Key: conv.conv1.bias
Remapped Key: conv1.bias
Shape: [2]
Dtype: F32
---
Original Key: conv.conv1.weight
Remapped Key: conv1.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
Original Key: conv.conv2.weight
Remapped Key: conv2.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
```
### Loading the model weights to a partial model
`PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model
@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
### Specifying the top-level key for state_dict
Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
Sometimes the
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
is nested under a top-level key along with other metadata as in a
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
For example, the `state_dict` of the whisper model is nested under `model_state_dict` key.
In this case, you can specify the top-level key in `LoadArgs`:
For example, the `state_dict` of the whisper model is nested under `model_state_dict` key. In this
case, you can specify the top-level key in `LoadArgs`:
```rust
let device = Default::default();

View File

@ -182,19 +182,25 @@ impl NestedValue {
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
/// for more information.
///
/// # Returns
///
/// A map of tensors with the remapped keys.
/// A map of tensors with the remapped keys and
/// a vector of tuples containing the remapped and original.
pub fn remap<T>(
mut tensors: HashMap<String, T>,
key_remap: Vec<(Regex, String)>,
) -> HashMap<String, T> {
) -> (HashMap<String, T>, Vec<(String, String)>) {
if key_remap.is_empty() {
return tensors;
let remapped_names = tensors
.keys()
.cloned()
.map(|s| (s.clone(), s)) // Name is the same as the remapped name
.collect();
return (tensors, remapped_names);
}
let mut remapped = HashMap::new();
let mut remapped_names = Vec::new();
for (name, tensor) in tensors.drain() {
let mut new_name = name.clone();
@ -205,10 +211,12 @@ pub fn remap<T>(
.to_string();
}
}
remapped_names.push((new_name.clone(), name));
remapped.insert(new_name, tensor);
}
remapped
(remapped, remapped_names)
}
/// Helper function to insert a value into a nested map/vector of tensors.

View File

@ -41,7 +41,8 @@ mod tests {
fn key_remap() {
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
.with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
.with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
.with_debug_print();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)

View File

@ -35,6 +35,7 @@ pub fn from_file<PS, D, B>(
path: &Path,
key_remap: Vec<(Regex, String)>,
top_level_key: Option<&str>,
debug: bool,
) -> Result<D, Error>
where
D: DeserializeOwned,
@ -48,7 +49,28 @@ where
.collect();
// Remap the keys (replace the keys in the map with the new keys)
let tensors = remap(tensors, key_remap);
let (tensors, remapped_keys) = remap(tensors, key_remap);
// Print the remapped keys if debug is enabled
if debug {
let mut remapped_keys = remapped_keys;
remapped_keys.sort();
println!("Debug information of keys and tensor shapes:\n---");
for (new_key, old_key) in remapped_keys {
if old_key != new_key {
println!("Original Key: {old_key}");
println!("Remapped Key: {new_key}");
} else {
println!("Key: {}", new_key);
}
let shape = tensors[&new_key].shape();
let dtype = tensors[&new_key].dtype();
println!("Shape: {shape:?}");
println!("Dtype: {dtype:?}");
println!("---");
}
}
// Convert the vector of Candle tensors to a nested value data structure
let nested_value = unflatten::<PS, _>(tensors)?;

View File

@ -48,6 +48,7 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
&args.file,
args.key_remap,
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
args.debug,
)?;
Ok(R::from_item(item, device))
}
@ -92,10 +93,13 @@ pub struct LoadArgs {
/// Top-level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top-level key in a dict.
pub top_level_key: Option<String>,
/// Whether to print debug information.
pub debug: bool,
}
impl LoadArgs {
/// Create a new `LoadArgs` instance.
/// Creates a new `LoadArgs` instance.
///
/// # Arguments
///
@ -105,10 +109,11 @@ impl LoadArgs {
file,
key_remap: Vec::new(),
top_level_key: None,
debug: false,
}
}
/// Set key remapping.
/// Sets key remapping.
///
/// # Arguments
///
@ -125,7 +130,7 @@ impl LoadArgs {
self
}
/// Set top-level key to load state_dict from the file.
/// Sets the top-level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top-level key in a dict.
///
/// # Arguments
@ -135,6 +140,12 @@ impl LoadArgs {
self.top_level_key = Some(key.into());
self
}
/// Sets printing debug information on.
pub fn with_debug_print(mut self) -> Self {
self.debug = true;
self
}
}
impl From<PathBuf> for LoadArgs {