Add Enum module support in PyTorchFileRecorder (#1436)

* Add Enum module support in PyTorchFileRecorder

Fixes #1431

* Fix wording/typos per PR feedback
Dilshod Tadjibaev 2024-03-11 11:21:01 -05:00 committed by GitHub
parent 9d4fbc5a35
commit 0138e16af6
7 changed files with 445 additions and 8 deletions

View File

@@ -347,6 +347,58 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.expect("Should decode state successfully")
```
### Models containing enum modules
Burn supports models containing enum modules with newtype variants (tuples with a single
item). The PyTorchFileRecorder imports weights for such models automatically. Note, however,
that since the source weights file carries no information about the enum variant, the variant
is selected by matching the weights against each variant's type. Consider the following example:
```rust
#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
DwsConv(DwsConv<B>),
Conv(Conv2d<B>),
}
#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
dconv: Conv2d<B>,
pconv: Conv2d<B>,
}
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv: Conv<B>,
}
```
If the source weights file contains weights for `DwsConv`, for example with the following keys:
```text
---
Key: conv.dconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.dconv.weight
Shape: [2, 1, 3, 3]
Dtype: F32
---
Key: conv.pconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.pconv.weight
Shape: [2, 2, 1, 1]
Dtype: F32
```
The weights will be imported into the `DwsConv` variant of the `Conv` enum module.
If two variants wrap identical types, the first variant is picked. In practice this is rarely a
problem, since variant types usually differ.
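For illustration, loading such a file follows the same pattern as loading a struct module; the
variant is resolved while the record is deserialized. Below is a minimal sketch, assuming an
`NdArray` backend alias, a hypothetical `depthwise.pt` file, and a `new_with` constructor on
`Net` (like the one used in the crate's tests):
```rust
type Backend = burn_ndarray::NdArray<f32>;

let device = Default::default();
let load_args = LoadArgs::new("depthwise.pt".into());

// The `DwsConv` variant is selected automatically because the keys in the
// source file match the structure of that variant.
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::new_with(record);
```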
## Current known issues
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).

View File

@@ -1,8 +1,10 @@
use core::ptr;
use std::collections::HashMap;
use super::data::NestedValue;
use super::{adapter::BurnModuleAdapter, error::Error};
use serde::de::{EnumAccess, VariantAccess};
use serde::{
de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
forward_to_deserialize_any,
@@ -313,16 +315,65 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
unimplemented!("deserialize_tuple_struct is not implemented")
}
/// Deserializes an enum by attempting to match its variants against the provided data.
///
/// This function attempts to deserialize an enum by iterating over its possible variants
/// and trying to deserialize the data into each one until an attempt succeeds. This is
/// necessary because the data itself carries no indication of which variant to use.
///
/// This is similar to Serde's
/// [untagged enum deserialization](https://serde.rs/enum-representations.html#untagged),
/// but it's on the deserializer side. Using `#[serde(untagged)]` on the enum will force
/// using `deserialize_any`, which is not what we want because we need methods such as
/// `visit_struct`. We also do not want to auto-generate `Deserialize` code just for
/// enums, because doing so would affect other serialization and deserialization
/// formats, such as JSON and Bincode.
///
/// # Safety
/// The function uses an unsafe block to clone the `visitor`. This is necessary because
/// the `Visitor` trait does not require a `Clone` implementation, and we need to use
/// the visitor multiple times. The visitor is a code-generated unit struct with no
/// state or mutation, so cloning it bitwise is safe in this case. We mainly care about
/// the `visit_enum` method, which is the only method called on the cloned visitor.
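///
/// For example, given the enum `Conv { DwsConv(DwsConv<B>), Conv(Conv2d<B>) }` from the
/// book example in this change, the deserializer first probes the `DwsConv` variant and
/// falls back to `Conv` if that probe fails.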
fn deserialize_enum<V>(
self,
_name: &'static str,
variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_enum is not implemented")
fn clone_unsafely<T>(thing: &T) -> T {
unsafe {
// Bitwise-copy the value out of the reference into a new owned instance.
// This is sound because the visitor is a stateless, code-generated unit
// struct (see the Safety note above), so no owned resources are duplicated.
ptr::read(thing as *const T)
}
}
// Try each variant in order
for &variant in variants {
// Clone the visitor so it can be reused on the next iteration.
let cloned_visitor = clone_unsafely(&visitor);
let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(
self.value.clone().unwrap(),
variant.to_owned(),
self.default_for_missing_fields,
));
if result.is_ok() {
return result;
}
}
Err(de::Error::custom("no enum variant matched the data"))
}
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
@@ -431,6 +482,82 @@ where
}
}
struct ProbeEnumAccess<A: BurnModuleAdapter> {
value: NestedValue,
current_variant: String,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}
impl<A: BurnModuleAdapter> ProbeEnumAccess<A> {
fn new(value: NestedValue, current_variant: String, default_for_missing_fields: bool) -> Self {
ProbeEnumAccess {
value,
current_variant,
default_for_missing_fields,
phantom: std::marker::PhantomData,
}
}
}
impl<'de, A> EnumAccess<'de> for ProbeEnumAccess<A>
where
A: BurnModuleAdapter,
{
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where
V: DeserializeSeed<'de>,
{
seed.deserialize(self.current_variant.clone().into_deserializer())
.map(|v| (v, self))
}
}
impl<'de, A> VariantAccess<'de> for ProbeEnumAccess<A>
where
A: BurnModuleAdapter,
{
type Error = Error;
fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
where
T: DeserializeSeed<'de>,
{
let value = seed.deserialize(
NestedValueWrapper::<A>::new(self.value, self.default_for_missing_fields)
.into_deserializer(),
)?;
Ok(value)
}
fn unit_variant(self) -> Result<(), Self::Error> {
unimplemented!("unit variant is not implemented because it is not used in the burn module")
}
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("tuple variant is not implemented because it is not used in the burn module")
}
fn struct_variant<V>(
self,
_fields: &'static [&'static str],
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!(
"struct variant is not implemented because it is not used in the burn module"
)
}
}
/// A wrapper for the nested value data structure with a burn module adapter.
struct NestedValueWrapper<A: BurnModuleAdapter> {
value: NestedValue,
@@ -601,11 +728,14 @@ impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
where
V: Visitor<'de>,
{
// Return an error (rather than panicking) when the source values are missing.
Err(Error::Other(
format!(
"Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct",
self.originator_field_name.unwrap_or("UNKNOWN".to_string()),
name,
)
))
}
fn deserialize_tuple_struct<V>(

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor
class DwsConv(nn.Module):
"""Depthwise separable convolution."""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
super().__init__()
# Depthwise conv
self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels)
# Pointwise conv
self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x: Tensor) -> Tensor:
x = self.dconv(x)
return self.pconv(x)
class Model(nn.Module):
def __init__(self, depthwise: bool = False) -> None:
super().__init__()
self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3)
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "enum_depthwise_false.pt")
input = torch.rand(1, 2, 5, 5)
print("Depthwise is False")
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
print("Depthwise is True")
model = Model(depthwise=True).to(torch.device("cpu"))
torch.save(model.state_dict(), "enum_depthwise_true.pt")
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,194 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
DwsConv(DwsConv<B>),
Conv(Conv2d<B>),
}
#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
dconv: Conv2d<B>,
pconv: Conv2d<B>,
}
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv: Conv<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let conv = match record.conv {
ConvRecord::DwsConv(dws_conv) => {
let dconv = Conv2dConfig::new([2, 2], [3, 3])
.with_groups(2)
.init_with(dws_conv.dconv);
let pconv = Conv2dConfig::new([2, 2], [1, 1])
.with_groups(1)
.init_with(dws_conv.pconv);
Conv::DwsConv(DwsConv { dconv, pconv })
}
ConvRecord::Conv(conv) => {
let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]);
Conv::Conv(conv2d_config.init_with(conv))
}
};
Net { conv }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
match &self.conv {
Conv::DwsConv(dws_conv) => {
let x = dws_conv.dconv.forward(x);
dws_conv.pconv.forward(x)
}
Conv::Conv(conv) => conv.forward(x),
}
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
use super::*;
#[test]
fn depthwise_false() {
let device = Default::default();
let load_args =
LoadArgs::new("tests/enum_module/enum_depthwise_false.pt".into()).with_debug_print();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[
[0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
[0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
[0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
[0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
[
0.804_481_1,
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
],
],
[
[0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
[0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
[0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
[
0.03694397,
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
],
[0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[0.35449377, -0.02832414, 0.490_976_1],
[0.29709217, 0.332_586_3, 0.30594018],
[0.18101373, 0.30932188, 0.30558896],
],
[
[-0.17683622, -0.13244139, -0.05608707],
[0.23467252, -0.07038684, 0.255_044_1],
[-0.241_931_3, -0.20476191, -0.14468731],
],
]],
&device,
);
output.to_data().assert_approx_eq(&expected.to_data(), 7);
}
#[test]
fn depthwise_true() {
let device = Default::default();
let load_args =
LoadArgs::new("tests/enum_module/enum_depthwise_true.pt".into()).with_debug_print();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[
[0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
[0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
[0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
[0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
[
0.804_481_1,
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
],
],
[
[0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
[0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
[0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
[
0.03694397,
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
],
[0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[0.77874625, 0.859_017_6, 0.834_283_5],
[0.773_056_4, 0.73817325, 0.78292674],
[0.710_775_2, 0.747_187_2, 0.733_264_4],
],
[
[-0.44891885, -0.49027523, -0.394_170_7],
[-0.43836114, -0.33961445, -0.387_311_5],
[-0.581_134_3, -0.34197026, -0.535_035_7],
],
]],
&device,
);
output.to_data().assert_approx_eq(&expected.to_data(), 7);
}
}

View File

@@ -8,6 +8,7 @@ mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod enum_module;
mod group_norm;
mod integer;
mod key_remap;