mirror of https://github.com/tracel-ai/burn.git
PyTorchFileRecorder print debug option (#1425)
* Add debug print option to PyTorchFileRecorder * Updated documentation and improved print output * Improve print wording * Updated per PR feedback
This commit is contained in:
parent
b429cc39c1
commit
545444c02a
|
@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in
|
|||
You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes
|
||||
by specifying a regular expression (See
|
||||
[regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and
|
||||
[try it online](https://rregex.dev/?version=1.10&method=replace)) to
|
||||
match the attribute name and a replacement string in `LoadArgs`:
|
||||
[try it online](https://rregex.dev/?version=1.10&method=replace)) to match the attribute name and a
|
||||
replacement string in `LoadArgs`:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
|
@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
|||
let model = Net::<Backend>::new_with(record);
|
||||
```
|
||||
|
||||
### Printing the source model keys and tensor information
|
||||
|
||||
If you are unsure about the keys in the source model, you can print them using the following code:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||
.with_key_remap("conv\\.(.*)", "$1")
|
||||
.with_debug_print(); // Print the keys and remapped keys
|
||||
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
let model = Net::<Backend>::new_with(record);
|
||||
```
|
||||
|
||||
Here is an example of the output:
|
||||
|
||||
```text
|
||||
Debug information of keys and tensor shapes:
|
||||
---
|
||||
Original Key: conv.conv1.bias
|
||||
Remapped Key: conv1.bias
|
||||
Shape: [2]
|
||||
Dtype: F32
|
||||
---
|
||||
Original Key: conv.conv1.weight
|
||||
Remapped Key: conv1.weight
|
||||
Shape: [2, 2, 2, 2]
|
||||
Dtype: F32
|
||||
---
|
||||
Original Key: conv.conv2.weight
|
||||
Remapped Key: conv2.weight
|
||||
Shape: [2, 2, 2, 2]
|
||||
Dtype: F32
|
||||
---
|
||||
```
|
||||
|
||||
### Loading the model weights to a partial model
|
||||
|
||||
`PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model
|
||||
|
@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
|
|||
|
||||
### Specifying the top-level key for state_dict
|
||||
|
||||
Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
|
||||
Sometimes the
|
||||
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
|
||||
is nested under a top-level key along with other metadata as in a
|
||||
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
|
||||
For example, the `state_dict` of the whisper model is nested under `model_state_dict` key.
|
||||
In this case, you can specify the top-level key in `LoadArgs`:
|
||||
For example, the `state_dict` of the whisper model is nested under the `model_state_dict` key. In this
|
||||
case, you can specify the top-level key in `LoadArgs`:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
|
|
|
@ -182,19 +182,25 @@ impl NestedValue {
|
|||
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
|
||||
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
|
||||
/// for more information.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A map of tensors with the remapped keys.
|
||||
/// A map of tensors with the remapped keys and
|
||||
/// a vector of tuples containing the remapped and original keys.
|
||||
pub fn remap<T>(
|
||||
mut tensors: HashMap<String, T>,
|
||||
key_remap: Vec<(Regex, String)>,
|
||||
) -> HashMap<String, T> {
|
||||
) -> (HashMap<String, T>, Vec<(String, String)>) {
|
||||
if key_remap.is_empty() {
|
||||
return tensors;
|
||||
let remapped_names = tensors
|
||||
.keys()
|
||||
.cloned()
|
||||
.map(|s| (s.clone(), s)) // Name is the same as the remapped name
|
||||
.collect();
|
||||
return (tensors, remapped_names);
|
||||
}
|
||||
|
||||
let mut remapped = HashMap::new();
|
||||
let mut remapped_names = Vec::new();
|
||||
|
||||
for (name, tensor) in tensors.drain() {
|
||||
let mut new_name = name.clone();
|
||||
|
@ -205,10 +211,12 @@ pub fn remap<T>(
|
|||
.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
remapped_names.push((new_name.clone(), name));
|
||||
remapped.insert(new_name, tensor);
|
||||
}
|
||||
|
||||
remapped
|
||||
(remapped, remapped_names)
|
||||
}
|
||||
|
||||
/// Helper function to insert a value into a nested map/vector of tensors.
|
||||
|
|
|
@ -41,7 +41,8 @@ mod tests {
|
|||
fn key_remap() {
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||
.with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||
.with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||
.with_debug_print();
|
||||
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
|
|
|
@ -35,6 +35,7 @@ pub fn from_file<PS, D, B>(
|
|||
path: &Path,
|
||||
key_remap: Vec<(Regex, String)>,
|
||||
top_level_key: Option<&str>,
|
||||
debug: bool,
|
||||
) -> Result<D, Error>
|
||||
where
|
||||
D: DeserializeOwned,
|
||||
|
@ -48,7 +49,28 @@ where
|
|||
.collect();
|
||||
|
||||
// Remap the keys (replace the keys in the map with the new keys)
|
||||
let tensors = remap(tensors, key_remap);
|
||||
let (tensors, remapped_keys) = remap(tensors, key_remap);
|
||||
|
||||
// Print the remapped keys if debug is enabled
|
||||
if debug {
|
||||
let mut remapped_keys = remapped_keys;
|
||||
remapped_keys.sort();
|
||||
println!("Debug information of keys and tensor shapes:\n---");
|
||||
for (new_key, old_key) in remapped_keys {
|
||||
if old_key != new_key {
|
||||
println!("Original Key: {old_key}");
|
||||
println!("Remapped Key: {new_key}");
|
||||
} else {
|
||||
println!("Key: {}", new_key);
|
||||
}
|
||||
|
||||
let shape = tensors[&new_key].shape();
|
||||
let dtype = tensors[&new_key].dtype();
|
||||
println!("Shape: {shape:?}");
|
||||
println!("Dtype: {dtype:?}");
|
||||
println!("---");
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the vector of Candle tensors to a nested value data structure
|
||||
let nested_value = unflatten::<PS, _>(tensors)?;
|
||||
|
|
|
@ -48,6 +48,7 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
|
|||
&args.file,
|
||||
args.key_remap,
|
||||
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
|
||||
args.debug,
|
||||
)?;
|
||||
Ok(R::from_item(item, device))
|
||||
}
|
||||
|
@ -92,10 +93,13 @@ pub struct LoadArgs {
|
|||
/// Top-level key to load state_dict from the file.
|
||||
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
||||
pub top_level_key: Option<String>,
|
||||
|
||||
/// Whether to print debug information.
|
||||
pub debug: bool,
|
||||
}
|
||||
|
||||
impl LoadArgs {
|
||||
/// Create a new `LoadArgs` instance.
|
||||
/// Creates a new `LoadArgs` instance.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -105,10 +109,11 @@ impl LoadArgs {
|
|||
file,
|
||||
key_remap: Vec::new(),
|
||||
top_level_key: None,
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set key remapping.
|
||||
/// Sets key remapping.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -125,7 +130,7 @@ impl LoadArgs {
|
|||
self
|
||||
}
|
||||
|
||||
/// Set top-level key to load state_dict from the file.
|
||||
/// Sets the top-level key to load state_dict from the file.
|
||||
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -135,6 +140,12 @@ impl LoadArgs {
|
|||
self.top_level_key = Some(key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables printing of debug information.
|
||||
pub fn with_debug_print(mut self) -> Self {
|
||||
self.debug = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for LoadArgs {
|
||||
|
|
Loading…
Reference in New Issue