From c603c68258e9fa7129373c2cf8936a9008981095 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Wed, 7 Feb 2024 16:51:48 -0500
Subject: [PATCH] Fix chain pattern matching when multiple patterns are
 provided (#1273)

---
 burn-core/src/record/serde/data.rs                 |   6 +-
 .../tests/key_remap_chained/export_weights.py      |  57 ++++++
 .../tests/key_remap_chained/key_remap.pt           | Bin 0 -> 7052 bytes
 .../tests/key_remap_chained/mod.rs                 | 182 ++++++++++++++++++
 burn-import/pytorch-tests/tests/mod.rs             |   1 +
 5 files changed, 244 insertions(+), 2 deletions(-)
 create mode 100755 burn-import/pytorch-tests/tests/key_remap_chained/export_weights.py
 create mode 100644 burn-import/pytorch-tests/tests/key_remap_chained/key_remap.pt
 create mode 100644 burn-import/pytorch-tests/tests/key_remap_chained/mod.rs

diff --git a/burn-core/src/record/serde/data.rs b/burn-core/src/record/serde/data.rs
index cca6d5047..0bdbe6abe 100644
--- a/burn-core/src/record/serde/data.rs
+++ b/burn-core/src/record/serde/data.rs
@@ -175,8 +175,10 @@ pub fn remap(
     for (name, tensor) in tensors.drain() {
         let mut new_name = name.clone();
         for (pattern, replacement) in &key_remap {
-            if pattern.is_match(&name) {
-                new_name = pattern.replace_all(&name, replacement.as_str()).to_string();
+            if pattern.is_match(&new_name) {
+                new_name = pattern
+                    .replace_all(&new_name, replacement.as_str())
+                    .to_string();
             }
         }
         remapped.insert(new_name, tensor);
diff --git a/burn-import/pytorch-tests/tests/key_remap_chained/export_weights.py b/burn-import/pytorch-tests/tests/key_remap_chained/export_weights.py
new file mode 100755
index 000000000..95da5a495
--- /dev/null
+++ b/burn-import/pytorch-tests/tests/key_remap_chained/export_weights.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+import torch
+from torch import nn, Tensor
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.block(x)
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(3, 6, 3, bias=False)
+        self.bn = nn.BatchNorm2d(6)
+        self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6))
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.layer(x)
+
+        return x
+
+
+def main():
+    torch.set_printoptions(precision=8)
+    torch.manual_seed(42)
+
+    model = Model()
+
+    input = torch.rand(1, 3, 4, 4)
+    model(input)  # condition batch norm
+    model.eval()
+
+    with torch.no_grad():
+        print(f"Input shape: {input.shape}")
+        print(f"Input type: {input.dtype}")
+        print(f"Input: {input}")
+        output = model(input)
+
+        print(f"Output: {output}")
+        print(f"Output Shape: {output.shape}")
+
+    torch.save(model.state_dict(), "key_remap.pt")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/burn-import/pytorch-tests/tests/key_remap_chained/key_remap.pt b/burn-import/pytorch-tests/tests/key_remap_chained/key_remap.pt
new file mode 100644
index 0000000000000000000000000000000000000000..8c2b11b39d470c3ef74a52aac38835f2cd9d6177
GIT binary patch
literal 7052
[7052 bytes of base85-encoded binary payload omitted]
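The two-line change in `data.rs` above is the heart of this patch: by matching and replacing against `new_name` instead of the original `name`, each remap pattern now operates on the output of the previous one, so multiple patterns chain. The following is a minimal, self-contained sketch of the fixed loop using the `regex` crate (the same `is_match`/`replace_all` API the patched code calls); the toy tensor map and the concrete key are illustrative, not part of the patch.

```rust
use regex::Regex;
use std::collections::HashMap;

fn main() {
    // Two of the patterns from the tests below, applied in order,
    // mirroring the patched `remap` loop.
    let key_remap = vec![
        (Regex::new(r"(.+)\.block\.1\.(.+)").unwrap(), "$1.bn.$2".to_string()),
        (Regex::new(r"layer\.([0-9])\.(.+)").unwrap(), "layer.blocks.$1.$2".to_string()),
    ];

    let mut tensors: HashMap<String, Vec<f32>> = HashMap::new();
    tensors.insert("layer.0.block.1.weight".into(), vec![0.0]);

    let mut remapped = HashMap::new();
    for (name, tensor) in tensors.drain() {
        let mut new_name = name.clone();
        for (pattern, replacement) in &key_remap {
            // Fixed behavior: each pattern sees the result of the previous one.
            if pattern.is_match(&new_name) {
                new_name = pattern
                    .replace_all(&new_name, replacement.as_str())
                    .to_string();
            }
        }
        remapped.insert(new_name, tensor);
    }

    // Chained: "layer.0.block.1.weight" -> "layer.0.bn.weight"
    //                                   -> "layer.blocks.0.bn.weight".
    // With the old code (matching on `name`), the second pattern was applied
    // to the original key, yielding "layer.blocks.0.block.1.weight" and
    // discarding the first rewrite.
    assert!(remapped.contains_key("layer.blocks.0.bn.weight"));
}
```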
diff --git a/burn-import/pytorch-tests/tests/key_remap_chained/mod.rs b/burn-import/pytorch-tests/tests/key_remap_chained/mod.rs
new file mode 100644
--- /dev/null
+++ b/burn-import/pytorch-tests/tests/key_remap_chained/mod.rs
@@ -0,0 +1,182 @@
+use std::marker::PhantomData;
+
+use burn::{
+    module::Module,
+    nn::{
+        conv::{Conv2d, Conv2dConfig},
+        BatchNorm, BatchNormConfig,
+    },
+    tensor::{backend::Backend, Device, Tensor},
+};
+
+pub trait ForwardModule<B: Backend> {
+    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
+}
+
+/// Conv2d + BatchNorm block.
+#[derive(Module, Debug)]
+pub struct ConvBlock<B: Backend> {
+    conv: Conv2d<B>,
+    bn: BatchNorm<B, 2>,
+}
+
+impl<B: Backend> ForwardModule<B> for ConvBlock<B> {
+    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+        let out = self.conv.forward(input);
+        self.bn.forward(out)
+    }
+}
+
+impl<B: Backend> ConvBlock<B> {
+    pub fn new(in_channels: usize, out_channels: usize, device: &Device<B>) -> Self {
+        let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1])
+            .with_bias(false)
+            .init(device);
+        let bn = BatchNormConfig::new(out_channels).init(device);
+
+        Self { conv, bn }
+    }
+}
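For orientation: the `nn.Sequential` containers in the export script produce index-based `state_dict` keys, while the Burn modules in this file address the same parameters through named fields, which is why the chained remapping is needed at all. The mapping below is derived from the remap patterns used in the tests further down (shown for illustration; the `bn.*` keys are additionally adapted to Burn's batch-norm field names by the recorder):

```text
PyTorch state_dict key       -> key after chained remapping
layer.0.block.0.weight       -> layer.blocks.0.conv.weight
layer.0.block.1.weight       -> layer.blocks.0.bn.weight
layer.1.block.0.weight       -> layer.blocks.1.conv.weight
layer.1.block.1.running_var  -> layer.blocks.1.bn.running_var
```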
+
+/// Collection of sequential blocks.
+#[derive(Module, Debug)]
+pub struct ModuleBlock<B: Backend, M> {
+    blocks: Vec<M>,
+    _backend: PhantomData<B>,
+}
+
+impl<B: Backend, M: ForwardModule<B>> ModuleBlock<B, M> {
+    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+        let mut out = input;
+        for block in &self.blocks {
+            out = block.forward(out);
+        }
+        out
+    }
+}
+
+impl<B: Backend> ModuleBlock<B, ConvBlock<B>> {
+    pub fn new(device: &Device<B>) -> Self {
+        let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)];
+
+        Self {
+            blocks,
+            _backend: PhantomData,
+        }
+    }
+}
+
+#[derive(Module, Debug)]
+pub struct Model<B: Backend, M> {
+    conv: Conv2d<B>,
+    bn: BatchNorm<B, 2>,
+    layer: ModuleBlock<B, M>,
+}
+
+impl<B: Backend> Model<B, ConvBlock<B>> {
+    pub fn new(device: &Device<B>) -> Self {
+        let conv = Conv2dConfig::new([3, 6], [3, 3])
+            .with_bias(false)
+            .init(device);
+        let bn = BatchNormConfig::new(6).init(device);
+
+        let layer = ModuleBlock::new(device);
+
+        Self { conv, bn, layer }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+        let out = self.conv.forward(input);
+        let out = self.bn.forward(out);
+        self.layer.forward(out)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    type Backend = burn_ndarray::NdArray<f32>;
+
+    use burn::record::{FullPrecisionSettings, Recorder};
+    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
+
+    use super::*;
+
+    #[test]
+    #[should_panic]
+    fn key_remap_chained_missing_pattern() {
+        // Loading the record should fail due to the missing pattern to map layer.blocks
+        let device = Default::default();
+        let load_args = LoadArgs::new("tests/key_remap_chained/key_remap.pt".into())
+            // Map *.block.0.* -> *.conv.*
+            .with_key_remap("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
+            // Map *.block.1.* -> *.bn.*
+            .with_key_remap("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2");
+
+        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
+            .load(load_args, &device)
+            .expect("Should decode state successfully");
+
+        let model: Model<Backend, _> = Model::new(&device);
+
+        model.load_record(record);
+    }
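To make the chaining concrete, here is how a single key flows through the patterns used in `key_remap_chained` below (a worked trace, not code from the patch):

```text
"layer.0.block.0.weight"
  (.+)\.block\.0\.(.+)  / $1.conv.$2          => "layer.0.conv.weight"
  layer\.([0-9])\.(.+)  / layer.blocks.$1.$2  => "layer.blocks.0.conv.weight"
```

Before this fix, the third pattern was applied to the original key and produced `layer.blocks.0.block.0.weight`, silently losing the first rewrite. The trace also shows why `key_remap_chained_missing_pattern` above must panic: without the third pattern the keys keep their `layer.0.` prefix and never line up with the module's `layer.blocks.0.` field path.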
+
+    #[test]
+    fn key_remap_chained() {
+        let device = Default::default();
+        let load_args = LoadArgs::new("tests/key_remap_chained/key_remap.pt".into())
+            // Map *.block.0.* -> *.conv.*
+            .with_key_remap("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
+            // Map *.block.1.* -> *.bn.*
+            .with_key_remap("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2")
+            // Map layer.[i].* -> layer.blocks.[i].*
+            .with_key_remap("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2");
+
+        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
+            .load(load_args, &device)
+            .expect("Should decode state successfully");
+
+        let model: Model<Backend, _> = Model::new(&device);
+
+        let model = model.load_record(record);
+
+        let input = Tensor::<Backend, 4>::from_data(
+            [[
+                [
+                    [0.76193494, 0.626_546_1, 0.49510366, 0.11974698],
+                    [0.07161391, 0.03232569, 0.704_681, 0.254_516],
+                    [0.399_373_7, 0.21224737, 0.40888822, 0.14808255],
+                    [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6],
+                ],
+                [
+                    [0.33959562, 0.13321638, 0.41178054, 0.257_626_3],
+                    [0.347_029_2, 0.02400219, 0.77974546, 0.15189773],
+                    [0.75130886, 0.726_892_1, 0.85721636, 0.11647397],
+                    [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734],
+                ],
+                [
+                    [0.42948407, 0.49613327, 0.38488472, 0.08250773],
+                    [0.73995143, 0.00364107, 0.81039995, 0.87411255],
+                    [0.972_853_2, 0.38206023, 0.08917904, 0.61241513],
+                    [0.77621365, 0.00234562, 0.38650817, 0.20027226],
+                ],
+            ]],
+            &device,
+        );
+        let expected = Tensor::<Backend, 4>::from_data(
+            [[
+                [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]],
+                [[0.17582723, 0.11344293], [0.05444185, 0.13307181]],
+                [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]],
+                [[0.00230906, -0.02177845], [0.01129148, 0.00925517]],
+                [[0.14751078, 0.14433631], [0.05498439, 0.29049855]],
+                [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]],
+            ]],
+            &device,
+        );
+
+        let output = model.forward(input);
+        output.to_data().assert_approx_eq(&expected.to_data(), 7);
+    }
+}
diff --git a/burn-import/pytorch-tests/tests/mod.rs b/burn-import/pytorch-tests/tests/mod.rs
index 136ed44d6..daa0b30a6 100644
--- a/burn-import/pytorch-tests/tests/mod.rs
+++ b/burn-import/pytorch-tests/tests/mod.rs
@@ -14,6 +14,7 @@ cfg_if::cfg_if! {
         mod group_norm;
         mod integer;
         mod key_remap;
+        mod key_remap_chained;
         mod layer_norm;
         mod linear;
     }