Fix chain pattern matching when multiple patterns are provided (#1273)
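
Previously, each remapping pattern in key_remap was matched against the original tensor name, so with multiple patterns every replacement started from scratch and the last matching pattern overwrote the effect of the earlier ones. The remap loop now matches and replaces against the running new_name, so patterns chain: the output of one replacement becomes the input of the next.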

Guillaume Lagrange 2024-02-07 16:51:48 -05:00 committed by GitHub
parent 419e53bc42
commit c603c68258
5 changed files with 244 additions and 2 deletions

@@ -175,8 +175,10 @@ pub fn remap<T>(
     for (name, tensor) in tensors.drain() {
         let mut new_name = name.clone();
         for (pattern, replacement) in &key_remap {
-            if pattern.is_match(&name) {
-                new_name = pattern.replace_all(&name, replacement.as_str()).to_string();
+            if pattern.is_match(&new_name) {
+                new_name = pattern
+                    .replace_all(&new_name, replacement.as_str())
+                    .to_string();
             }
         }
         remapped.insert(new_name, tensor);
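
To illustrate why matching against new_name matters, here is a minimal standalone sketch of the fixed chaining behavior (it uses the regex crate directly rather than the crate's internal types; the key and patterns mirror the test below):

use regex::Regex;

fn main() {
    let key_remap = [
        (Regex::new(r"(.+)\.block\.0\.(.+)").unwrap(), "$1.conv.$2"),
        (Regex::new(r"layer\.([0-9])\.(.+)").unwrap(), "layer.blocks.$1.$2"),
    ];

    let name = "layer.0.block.0.weight".to_string();
    let mut new_name = name.clone();
    for (pattern, replacement) in &key_remap {
        // Match against the running `new_name` so each pattern builds on the
        // previous replacement instead of starting over from `name`.
        if pattern.is_match(&new_name) {
            new_name = pattern.replace_all(&new_name, *replacement).to_string();
        }
    }
    assert_eq!(new_name, "layer.blocks.0.conv.weight");
}

With the old code, the second pattern would have replaced on the original name instead, producing layer.blocks.0.block.0.weight and silently discarding the first remapping.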

@@ -0,0 +1,57 @@
#!/usr/bin/env python3

import torch
from torch import nn, Tensor


class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 3, bias=False)
        self.bn = nn.BatchNorm2d(6)
        self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6))

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.layer(x)
        return x


def main():
    torch.set_printoptions(precision=8)
    torch.manual_seed(42)

    model = Model()

    input = torch.rand(1, 3, 4, 4)
    model(input)  # condition batch norm
    model.eval()

    with torch.no_grad():
        print(f"Input shape: {input.shape}")
        print(f"Input type: {input.dtype}")
        print(f"Input: {input}")
        output = model(input)
        print(f"Output: {output}")
        print(f"Output Shape: {output.shape}")

    torch.save(model.state_dict(), "key_remap.pt")


if __name__ == "__main__":
    main()
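
Running this script once regenerates key_remap.pt next to it (the test loads it from tests/key_remap_chained/key_remap.pt). Because the weights and input are seeded with torch.manual_seed(42), the printed output is reproducible; the expected tensor in the Rust test below was presumably captured from this printout.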

@@ -0,0 +1,182 @@
use std::marker::PhantomData;

use burn::{
    module::Module,
    nn::{
        conv::{Conv2d, Conv2dConfig},
        BatchNorm, BatchNormConfig,
    },
    tensor::{backend::Backend, Device, Tensor},
};

/// Some module that implements a specific method so it can be used in a sequential block.
pub trait ForwardModule<B: Backend> {
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
}

/// Conv2d + BatchNorm block.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
}

impl<B: Backend> ForwardModule<B> for ConvBlock<B> {
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let out = self.conv.forward(input);
        self.bn.forward(out)
    }
}

impl<B: Backend> ConvBlock<B> {
    pub fn new(in_channels: usize, out_channels: usize, device: &Device<B>) -> Self {
        let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1])
            .with_bias(false)
            .init(device);
        let bn = BatchNormConfig::new(out_channels).init(device);

        Self { conv, bn }
    }
}

/// Collection of sequential blocks.
#[derive(Module, Debug)]
pub struct ModuleBlock<B: Backend, M> {
    blocks: Vec<M>,
    _backend: PhantomData<B>,
}

impl<B: Backend, M: ForwardModule<B>> ModuleBlock<B, M> {
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let mut out = input;
        for block in &self.blocks {
            out = block.forward(out);
        }
        out
    }
}

impl<B: Backend> ModuleBlock<B, ConvBlock<B>> {
    pub fn new(device: &Device<B>) -> Self {
        let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)];

        Self {
            blocks,
            _backend: PhantomData,
        }
    }
}

#[derive(Module, Debug)]
pub struct Model<B: Backend, M> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
    layer: ModuleBlock<B, M>,
}

impl<B: Backend> Model<B, ConvBlock<B>> {
    pub fn new(device: &Device<B>) -> Self {
        let conv = Conv2dConfig::new([3, 6], [3, 3])
            .with_bias(false)
            .init(device);
        let bn = BatchNormConfig::new(6).init(device);
        let layer = ModuleBlock::new(device);

        Self { conv, bn, layer }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let out = self.conv.forward(input);
        let out = self.bn.forward(out);
        self.layer.forward(out)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

    use super::*;

    #[test]
    #[should_panic]
    fn key_remap_chained_missing_pattern() {
        // Loading the record should fail: no pattern maps layer.[i].* to layer.blocks.[i].*.
        let device = Default::default();
        let load_args = LoadArgs::new("tests/key_remap_chained/key_remap.pt".into())
            // Map *.block.0.* -> *.conv.*
            .with_key_remap("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
            // Map *.block.1.* -> *.bn.*
            .with_key_remap("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2");

        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(load_args, &device)
            .expect("Should decode state successfully");

        let model: Model<Backend, _> = Model::new(&device);
        model.load_record(record);
    }

    #[test]
    fn key_remap_chained() {
        let device = Default::default();
        let load_args = LoadArgs::new("tests/key_remap_chained/key_remap.pt".into())
            // Map *.block.0.* -> *.conv.*
            .with_key_remap("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
            // Map *.block.1.* -> *.bn.*
            .with_key_remap("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2")
            // Map layer.[i].* -> layer.blocks.[i].*
            .with_key_remap("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2");

        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(load_args, &device)
            .expect("Should decode state successfully");

        let model: Model<Backend, _> = Model::new(&device);
        let model = model.load_record(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.76193494, 0.626_546_1, 0.49510366, 0.11974698],
                    [0.07161391, 0.03232569, 0.704_681, 0.254_516],
                    [0.399_373_7, 0.21224737, 0.40888822, 0.14808255],
                    [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6],
                ],
                [
                    [0.33959562, 0.13321638, 0.41178054, 0.257_626_3],
                    [0.347_029_2, 0.02400219, 0.77974546, 0.15189773],
                    [0.75130886, 0.726_892_1, 0.85721636, 0.11647397],
                    [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734],
                ],
                [
                    [0.42948407, 0.49613327, 0.38488472, 0.08250773],
                    [0.73995143, 0.00364107, 0.81039995, 0.87411255],
                    [0.972_853_2, 0.38206023, 0.08917904, 0.61241513],
                    [0.77621365, 0.00234562, 0.38650817, 0.20027226],
                ],
            ]],
            &device,
        );

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]],
                [[0.17582723, 0.11344293], [0.05444185, 0.13307181]],
                [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]],
                [[0.00230906, -0.02177845], [0.01129148, 0.00925517]],
                [[0.14751078, 0.14433631], [0.05498439, 0.29049855]],
                [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]],
            ]],
            &device,
        );

        let output = model.forward(input);

        output.to_data().assert_approx_eq(&expected.to_data(), 7);
    }
}
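
For reference, the chained patterns take the PyTorch state dict keys produced by the export script to the field names of the Burn model above, for example:

conv.weight                 -> conv.weight (unchanged)
layer.0.block.0.weight      -> layer.blocks.0.conv.weight
layer.1.block.1.running_var -> layer.blocks.1.bn.running_var

The last mapping needs two patterns applied in sequence (block.1 -> bn, then layer.[i] -> layer.blocks.[i]), which is exactly the chaining this commit fixes.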

@@ -14,6 +14,7 @@ cfg_if::cfg_if! {
     mod group_norm;
     mod integer;
     mod key_remap;
+    mod key_remap_chained;
     mod layer_norm;
     mod linear;
 }