diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 12d81c8f2..6b2e73879 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes by specifying a regular expression (See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and -[try it online](https://rregex.dev/?version=1.10&method=replace)) to -match the attribute name and a replacement string in `LoadArgs`: +[try it online](https://rregex.dev/?version=1.10&method=replace)) to match the attribute name and a +replacement string in `LoadArgs`: ```rust let device = Default::default(); @@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::::default() let model = Net::::new_with(record); ``` +### Printing the source model keys and tensor information + +If you are unsure about the keys in the source model, you can print them using the following code: + +```rust +let device = Default::default(); +let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into()) + // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1" + .with_key_remap("conv\\.(.*)", "$1") + .with_debug_print(); // Print the keys and remapped keys + +let record = PyTorchFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + +let model = Net::::new_with(record); +``` + +Here is an example of the output: + +```text +Debug information of keys and tensor shapes: +--- +Original Key: conv.conv1.bias +Remapped Key: conv1.bias +Shape: [2] +Dtype: F32 +--- +Original Key: conv.conv1.weight +Remapped Key: conv1.weight +Shape: [2, 2, 2, 2] +Dtype: F32 +--- +Original Key: conv.conv2.weight +Remapped Key: conv2.weight +Shape: [2, 2, 2, 2] +Dtype: F32 +--- +``` + ### Loading the model weights to a partial model `PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model @@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excludin ### Specifying the top-level key for state_dict -Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict) +Sometimes the +[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict) is nested under a top-level key along with other metadata as in a [general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training). -For example, the `state_dict` of the whisper model is nested under `model_state_dict` key. -In this case, you can specify the top-level key in `LoadArgs`: +For example, the `state_dict` of the whisper model is nested under `model_state_dict` key. In this +case, you can specify the top-level key in `LoadArgs`: ```rust let device = Default::default(); diff --git a/crates/burn-core/src/record/serde/data.rs b/crates/burn-core/src/record/serde/data.rs index 92e057616..ed612ae4a 100644 --- a/crates/burn-core/src/record/serde/data.rs +++ b/crates/burn-core/src/record/serde/data.rs @@ -182,19 +182,25 @@ impl NestedValue { /// * `key_remap` - A vector of tuples containing a regular expression and a replacement string. /// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) /// for more information. -/// /// # Returns /// -/// A map of tensors with the remapped keys. +/// A map of tensors with the remapped keys and +/// a vector of tuples containing the remapped and original. pub fn remap( mut tensors: HashMap, key_remap: Vec<(Regex, String)>, -) -> HashMap { +) -> (HashMap, Vec<(String, String)>) { if key_remap.is_empty() { - return tensors; + let remapped_names = tensors + .keys() + .cloned() + .map(|s| (s.clone(), s)) // Name is the same as the remapped name + .collect(); + return (tensors, remapped_names); } let mut remapped = HashMap::new(); + let mut remapped_names = Vec::new(); for (name, tensor) in tensors.drain() { let mut new_name = name.clone(); @@ -205,10 +211,12 @@ pub fn remap( .to_string(); } } + + remapped_names.push((new_name.clone(), name)); remapped.insert(new_name, tensor); } - remapped + (remapped, remapped_names) } /// Helper function to insert a value into a nested map/vector of tensors. diff --git a/crates/burn-import/pytorch-tests/tests/key_remap/mod.rs b/crates/burn-import/pytorch-tests/tests/key_remap/mod.rs index 6c5634915..cec61d56d 100644 --- a/crates/burn-import/pytorch-tests/tests/key_remap/mod.rs +++ b/crates/burn-import/pytorch-tests/tests/key_remap/mod.rs @@ -41,7 +41,8 @@ mod tests { fn key_remap() { let device = Default::default(); let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into()) - .with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1" + .with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1" + .with_debug_print(); let record = PyTorchFileRecorder::::default() .load(load_args, &device) diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs index 61e62e8a2..c327d92b6 100644 --- a/crates/burn-import/src/pytorch/reader.rs +++ b/crates/burn-import/src/pytorch/reader.rs @@ -35,6 +35,7 @@ pub fn from_file( path: &Path, key_remap: Vec<(Regex, String)>, top_level_key: Option<&str>, + debug: bool, ) -> Result where D: DeserializeOwned, @@ -48,7 +49,28 @@ where .collect(); // Remap the keys (replace the keys in the map with the new keys) - let tensors = remap(tensors, key_remap); + let (tensors, remapped_keys) = remap(tensors, key_remap); + + // Print the remapped keys if debug is enabled + if debug { + let mut remapped_keys = remapped_keys; + remapped_keys.sort(); + println!("Debug information of keys and tensor shapes:\n---"); + for (new_key, old_key) in remapped_keys { + if old_key != new_key { + println!("Original Key: {old_key}"); + println!("Remapped Key: {new_key}"); + } else { + println!("Key: {}", new_key); + } + + let shape = tensors[&new_key].shape(); + let dtype = tensors[&new_key].dtype(); + println!("Shape: {shape:?}"); + println!("Dtype: {dtype:?}"); + println!("---"); + } + } // Convert the vector of Candle tensors to a nested value data structure let nested_value = unflatten::(tensors)?; diff --git a/crates/burn-import/src/pytorch/recorder.rs b/crates/burn-import/src/pytorch/recorder.rs index e1888c849..170f64a9d 100644 --- a/crates/burn-import/src/pytorch/recorder.rs +++ b/crates/burn-import/src/pytorch/recorder.rs @@ -48,6 +48,7 @@ impl Recorder for PyTorchFileRecorder &args.file, args.key_remap, args.top_level_key.as_deref(), // Convert Option to Option<&str> + args.debug, )?; Ok(R::from_item(item, device)) } @@ -92,10 +93,13 @@ pub struct LoadArgs { /// Top-level key to load state_dict from the file. /// Sometimes the state_dict is nested under a top-level key in a dict. pub top_level_key: Option, + + /// Whether to print debug information. + pub debug: bool, } impl LoadArgs { - /// Create a new `LoadArgs` instance. + /// Creates a new `LoadArgs` instance. /// /// # Arguments /// @@ -105,10 +109,11 @@ impl LoadArgs { file, key_remap: Vec::new(), top_level_key: None, + debug: false, } } - /// Set key remapping. + /// Sets key remapping. /// /// # Arguments /// @@ -125,7 +130,7 @@ impl LoadArgs { self } - /// Set top-level key to load state_dict from the file. + /// Sets the top-level key to load state_dict from the file. /// Sometimes the state_dict is nested under a top-level key in a dict. /// /// # Arguments @@ -135,6 +140,12 @@ impl LoadArgs { self.top_level_key = Some(key.into()); self } + + /// Sets printing debug information on. + pub fn with_debug_print(mut self) -> Self { + self.debug = true; + self + } } impl From for LoadArgs {