mirror of https://github.com/tracel-ai/burn.git
PyTorchFileRecord print debug option (#1425)
* Add debug print option to PyTorchFileRecorder * Updated documentation and improved print output * Improve print wording * Updated per PR feedback
This commit is contained in:
parent
b429cc39c1
commit
545444c02a
|
@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in
|
||||||
You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes
|
You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes
|
||||||
by specifying a regular expression (See
|
by specifying a regular expression (See
|
||||||
[regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and
|
[regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and
|
||||||
[try it online](https://rregex.dev/?version=1.10&method=replace)) to
|
[try it online](https://rregex.dev/?version=1.10&method=replace)) to match the attribute name and a
|
||||||
match the attribute name and a replacement string in `LoadArgs`:
|
replacement string in `LoadArgs`:
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
|
@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||||
let model = Net::<Backend>::new_with(record);
|
let model = Net::<Backend>::new_with(record);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Printing the source model keys and tensor information
|
||||||
|
|
||||||
|
If you are unsure about the keys in the source model, you can print them using the following code:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let device = Default::default();
|
||||||
|
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||||
|
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||||
|
.with_key_remap("conv\\.(.*)", "$1")
|
||||||
|
.with_debug_print(); // Print the keys and remapped keys
|
||||||
|
|
||||||
|
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||||
|
.load(load_args, &device)
|
||||||
|
.expect("Should decode state successfully");
|
||||||
|
|
||||||
|
let model = Net::<Backend>::new_with(record);
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is an example of the output:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Debug information of keys and tensor shapes:
|
||||||
|
---
|
||||||
|
Original Key: conv.conv1.bias
|
||||||
|
Remapped Key: conv1.bias
|
||||||
|
Shape: [2]
|
||||||
|
Dtype: F32
|
||||||
|
---
|
||||||
|
Original Key: conv.conv1.weight
|
||||||
|
Remapped Key: conv1.weight
|
||||||
|
Shape: [2, 2, 2, 2]
|
||||||
|
Dtype: F32
|
||||||
|
---
|
||||||
|
Original Key: conv.conv2.weight
|
||||||
|
Remapped Key: conv2.weight
|
||||||
|
Shape: [2, 2, 2, 2]
|
||||||
|
Dtype: F32
|
||||||
|
---
|
||||||
|
```
|
||||||
|
|
||||||
### Loading the model weights to a partial model
|
### Loading the model weights to a partial model
|
||||||
|
|
||||||
`PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model
|
`PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model
|
||||||
|
@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
|
||||||
|
|
||||||
### Specifying the top-level key for state_dict
|
### Specifying the top-level key for state_dict
|
||||||
|
|
||||||
Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
|
Sometimes the
|
||||||
|
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
|
||||||
is nested under a top-level key along with other metadata as in a
|
is nested under a top-level key along with other metadata as in a
|
||||||
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
|
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
|
||||||
For example, the `state_dict` of the whisper model is nested under `model_state_dict` key.
|
For example, the `state_dict` of the whisper model is nested under the `model_state_dict` key. In this
|
||||||
In this case, you can specify the top-level key in `LoadArgs`:
|
case, you can specify the top-level key in `LoadArgs`:
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
|
|
|
@ -182,19 +182,25 @@ impl NestedValue {
|
||||||
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
|
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
|
||||||
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
|
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
|
||||||
/// for more information.
|
/// for more information.
|
||||||
///
|
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// A map of tensors with the remapped keys.
|
/// A map of tensors with the remapped keys and
|
||||||
|
/// a vector of tuples pairing each remapped key with its original key.
|
||||||
pub fn remap<T>(
|
pub fn remap<T>(
|
||||||
mut tensors: HashMap<String, T>,
|
mut tensors: HashMap<String, T>,
|
||||||
key_remap: Vec<(Regex, String)>,
|
key_remap: Vec<(Regex, String)>,
|
||||||
) -> HashMap<String, T> {
|
) -> (HashMap<String, T>, Vec<(String, String)>) {
|
||||||
if key_remap.is_empty() {
|
if key_remap.is_empty() {
|
||||||
return tensors;
|
let remapped_names = tensors
|
||||||
|
.keys()
|
||||||
|
.cloned()
|
||||||
|
.map(|s| (s.clone(), s)) // Name is the same as the remapped name
|
||||||
|
.collect();
|
||||||
|
return (tensors, remapped_names);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut remapped = HashMap::new();
|
let mut remapped = HashMap::new();
|
||||||
|
let mut remapped_names = Vec::new();
|
||||||
|
|
||||||
for (name, tensor) in tensors.drain() {
|
for (name, tensor) in tensors.drain() {
|
||||||
let mut new_name = name.clone();
|
let mut new_name = name.clone();
|
||||||
|
@ -205,10 +211,12 @@ pub fn remap<T>(
|
||||||
.to_string();
|
.to_string();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remapped_names.push((new_name.clone(), name));
|
||||||
remapped.insert(new_name, tensor);
|
remapped.insert(new_name, tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
remapped
|
(remapped, remapped_names)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper function to insert a value into a nested map/vector of tensors.
|
/// Helper function to insert a value into a nested map/vector of tensors.
|
||||||
|
|
|
@ -41,7 +41,8 @@ mod tests {
|
||||||
fn key_remap() {
|
fn key_remap() {
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||||
.with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
.with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||||
|
.with_debug_print();
|
||||||
|
|
||||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||||
.load(load_args, &device)
|
.load(load_args, &device)
|
||||||
|
|
|
@ -35,6 +35,7 @@ pub fn from_file<PS, D, B>(
|
||||||
path: &Path,
|
path: &Path,
|
||||||
key_remap: Vec<(Regex, String)>,
|
key_remap: Vec<(Regex, String)>,
|
||||||
top_level_key: Option<&str>,
|
top_level_key: Option<&str>,
|
||||||
|
debug: bool,
|
||||||
) -> Result<D, Error>
|
) -> Result<D, Error>
|
||||||
where
|
where
|
||||||
D: DeserializeOwned,
|
D: DeserializeOwned,
|
||||||
|
@ -48,7 +49,28 @@ where
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Remap the keys (replace the keys in the map with the new keys)
|
// Remap the keys (replace the keys in the map with the new keys)
|
||||||
let tensors = remap(tensors, key_remap);
|
let (tensors, remapped_keys) = remap(tensors, key_remap);
|
||||||
|
|
||||||
|
// Print the remapped keys if debug is enabled
|
||||||
|
if debug {
|
||||||
|
let mut remapped_keys = remapped_keys;
|
||||||
|
remapped_keys.sort();
|
||||||
|
println!("Debug information of keys and tensor shapes:\n---");
|
||||||
|
for (new_key, old_key) in remapped_keys {
|
||||||
|
if old_key != new_key {
|
||||||
|
println!("Original Key: {old_key}");
|
||||||
|
println!("Remapped Key: {new_key}");
|
||||||
|
} else {
|
||||||
|
println!("Key: {}", new_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let shape = tensors[&new_key].shape();
|
||||||
|
let dtype = tensors[&new_key].dtype();
|
||||||
|
println!("Shape: {shape:?}");
|
||||||
|
println!("Dtype: {dtype:?}");
|
||||||
|
println!("---");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Convert the vector of Candle tensors to a nested value data structure
|
// Convert the vector of Candle tensors to a nested value data structure
|
||||||
let nested_value = unflatten::<PS, _>(tensors)?;
|
let nested_value = unflatten::<PS, _>(tensors)?;
|
||||||
|
|
|
@ -48,6 +48,7 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
|
||||||
&args.file,
|
&args.file,
|
||||||
args.key_remap,
|
args.key_remap,
|
||||||
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
|
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
|
||||||
|
args.debug,
|
||||||
)?;
|
)?;
|
||||||
Ok(R::from_item(item, device))
|
Ok(R::from_item(item, device))
|
||||||
}
|
}
|
||||||
|
@ -92,10 +93,13 @@ pub struct LoadArgs {
|
||||||
/// Top-level key to load state_dict from the file.
|
/// Top-level key to load state_dict from the file.
|
||||||
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
||||||
pub top_level_key: Option<String>,
|
pub top_level_key: Option<String>,
|
||||||
|
|
||||||
|
/// Whether to print debug information.
|
||||||
|
pub debug: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoadArgs {
|
impl LoadArgs {
|
||||||
/// Create a new `LoadArgs` instance.
|
/// Creates a new `LoadArgs` instance.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
|
@ -105,10 +109,11 @@ impl LoadArgs {
|
||||||
file,
|
file,
|
||||||
key_remap: Vec::new(),
|
key_remap: Vec::new(),
|
||||||
top_level_key: None,
|
top_level_key: None,
|
||||||
|
debug: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set key remapping.
|
/// Sets key remapping.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
|
@ -125,7 +130,7 @@ impl LoadArgs {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set top-level key to load state_dict from the file.
|
/// Sets the top-level key to load state_dict from the file.
|
||||||
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
/// Sometimes the state_dict is nested under a top-level key in a dict.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -135,6 +140,12 @@ impl LoadArgs {
|
||||||
self.top_level_key = Some(key.into());
|
self.top_level_key = Some(key.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Enables printing of debug information (original keys, remapped keys, tensor shapes and dtypes).
|
||||||
|
pub fn with_debug_print(mut self) -> Self {
|
||||||
|
self.debug = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<PathBuf> for LoadArgs {
|
impl From<PathBuf> for LoadArgs {
|
||||||
|
|
Loading…
Reference in New Issue