mirror of https://github.com/tracel-ai/burn.git
Add Enum module support in PyTorchFileRecorder (#1436)
* Add Enum module support in PyTorchFileRecorder (Fixes #1431)
* Fix wording/typos per PR feedback
This commit is contained in:
parent 9d4fbc5a35
commit 0138e16af6
@ -347,6 +347,58 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .expect("Should decode state successfully")
```

### Models containing enum modules

Burn supports models containing enum modules with new-type variants (a tuple with a single item).
Importing weights for such models is automatically supported by the PyTorchFileRecorder. Note,
however, that since the source weights file does not contain the enum variant information, the
variant is selected based on the type it wraps. Consider the following example:

```rust
#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
    DwsConv(DwsConv<B>),
    Conv(Conv2d<B>),
}

#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
    dconv: Conv2d<B>,
    pconv: Conv2d<B>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv: Conv<B>,
}
```

If the source weights file contains weights for `DwsConv`, such as the following keys:

```text
---
Key: conv.dconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.dconv.weight
Shape: [2, 1, 3, 3]
Dtype: F32
---
Key: conv.pconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.pconv.weight
Shape: [2, 2, 1, 1]
Dtype: F32
```

The weights will be imported into the `DwsConv` variant of the `Conv` enum module.
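
For reference, a minimal sketch of the load call (the file path is illustrative; `NetRecord` and
`new_with` follow the example above, and `device`/`MyBackend` are assumed to be defined):

```rust
let load_args = LoadArgs::new("dws_model.pt".into());
let record: NetRecord<MyBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

// The keys above start with `conv.dconv` / `conv.pconv`, so the record holds
// the `DwsConv` variant and the model is built accordingly.
let model = Net::new_with(record);
```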

If the variant types are identical, the first variant is picked. In practice this is rarely a
problem, since the variant types usually differ.
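
For illustration, a hypothetical enum where this tie-breaking would apply (these types are not
part of the example above):

```rust
#[derive(Module, Debug)]
pub enum Block<B: Backend> {
    // Both variants wrap the same inner type, so the weight structure cannot
    // distinguish them; the first declared variant (`A`) is picked on import.
    A(Conv2d<B>),
    B(Conv2d<B>),
}
```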

## Current known issues

1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
@ -1,8 +1,10 @@
use core::ptr;
use std::collections::HashMap;

use super::data::NestedValue;
use super::{adapter::BurnModuleAdapter, error::Error};

use serde::de::{EnumAccess, VariantAccess};
use serde::{
    de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
    forward_to_deserialize_any,
@ -313,16 +315,65 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
        unimplemented!("deserialize_tuple_struct is not implemented")
    }

    /// Deserializes an enum by attempting to match its variants against the provided data.
    ///
    /// This function attempts to deserialize an enum by iterating over its possible variants
    /// and trying to deserialize the data into each until one succeeds. We need to do this
    /// because we have no way of knowing which variant to deserialize from the data alone.
    ///
    /// This is similar to Serde's
    /// [untagged enum deserialization](https://serde.rs/enum-representations.html#untagged),
    /// but on the deserializer side. Using `#[serde(untagged)]` on the enum would force
    /// `deserialize_any`, which is not what we want because we want to use methods such as
    /// `visit_struct`. We also do not wish to use auto-generated `Deserialize` code just for
    /// enums because it would affect other serialization and deserialization formats, such
    /// as JSON and Bincode.
    ///
    /// # Safety
    /// The function uses an unsafe block to clone the `visitor`. This is necessary because
    /// the `Visitor` trait does not have a `Clone` implementation, and we need to clone it
    /// since we are going to use it multiple times. The visitor is a code-generated unit
    /// struct with no state or mutation, so it is safe to clone in this case. We mainly care
    /// about the `visit_enum` method, which is the only method that will be called on the
    /// cloned visitor.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        fn clone_unsafely<T>(thing: &T) -> T {
            unsafe {
                // Reserve uninitialized memory for the clone.
                let mut clone = core::mem::MaybeUninit::<T>::uninit();
                // Copy the memory bit for bit.
                ptr::copy_nonoverlapping(thing as *const T, clone.as_mut_ptr(), 1);
                // Treat the copied bits as an owned instance of T.
                clone.assume_init()
            }
        }

        // Try each variant in order
        for &variant in variants {
            // Clone the visitor to avoid moving it
            let cloned_visitor = clone_unsafely(&visitor);
            let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(
                self.value.clone().unwrap(),
                variant.to_owned(),
                self.default_for_missing_fields,
            ));

            if result.is_ok() {
                return result;
            }
        }

        Err(de::Error::custom("No variant match"))
    }

    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
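
The probing strategy above, stripped of the serde machinery, reduces to the following sketch
(the `probe_variants` helper and `try_variant` closure are illustrative, not part of this commit):

```rust
/// Sketch: try each variant in declaration order and return the first success,
/// mirroring the loop in `deserialize_enum`.
fn probe_variants<T, E>(
    variants: &[&str],
    mut try_variant: impl FnMut(&str) -> Result<T, E>,
    no_match: impl FnOnce() -> E,
) -> Result<T, E> {
    for &variant in variants {
        if let Ok(value) = try_variant(variant) {
            return Ok(value);
        }
    }
    Err(no_match())
}
```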
@ -431,6 +482,82 @@
    }
}

struct ProbeEnumAccess<A: BurnModuleAdapter> {
    value: NestedValue,
    current_variant: String,
    default_for_missing_fields: bool,
    phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> ProbeEnumAccess<A> {
    fn new(value: NestedValue, current_variant: String, default_for_missing_fields: bool) -> Self {
        ProbeEnumAccess {
            value,
            current_variant,
            default_for_missing_fields,
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'de, A> EnumAccess<'de> for ProbeEnumAccess<A>
where
    A: BurnModuleAdapter,
{
    type Error = Error;
    type Variant = Self;

    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        seed.deserialize(self.current_variant.clone().into_deserializer())
            .map(|v| (v, self))
    }
}

impl<'de, A> VariantAccess<'de> for ProbeEnumAccess<A>
where
    A: BurnModuleAdapter,
{
    type Error = Error;

    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        let value = seed.deserialize(
            NestedValueWrapper::<A>::new(self.value, self.default_for_missing_fields)
                .into_deserializer(),
        )?;
        Ok(value)
    }

    fn unit_variant(self) -> Result<(), Self::Error> {
        unimplemented!("unit variant is not implemented because it is not used in the burn module")
    }

    fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("tuple variant is not implemented because it is not used in the burn module")
    }

    fn struct_variant<V>(
        self,
        _fields: &'static [&'static str],
        _visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!(
            "struct variant is not implemented because it is not used in the burn module"
        )
    }
}

/// A wrapper for the nested value data structure with a burn module adapter.
struct NestedValueWrapper<A: BurnModuleAdapter> {
    value: NestedValue,
@ -601,11 +728,14 @@ impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
    where
        V: Visitor<'de>,
    {
        // Return an error when the source values for the field are missing
        Err(Error::Other(
            format!(
                "Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct",
                self.originator_field_name.unwrap_or("UNKNOWN".to_string()),
                name,
            )
        ))
    }

    fn deserialize_tuple_struct<V>(
Binary file not shown.
Binary file not shown.
@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor


class DwsConv(nn.Module):
    """Depthwise separable convolution."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # Depthwise conv
        self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels)
        # Pointwise conv
        self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)

    def forward(self, x: Tensor) -> Tensor:
        x = self.dconv(x)
        return self.pconv(x)


class Model(nn.Module):
    def __init__(self, depthwise: bool = False) -> None:
        super().__init__()
        self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


def main():
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "enum_depthwise_false.pt")

    input = torch.rand(1, 2, 5, 5)

    print("Depthwise is False")
    print(f"Input shape: {input.shape}")
    print(f"Input: {input}")
    output = model(input)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")

    print("Depthwise is True")
    model = Model(depthwise=True).to(torch.device("cpu"))
    torch.save(model.state_dict(), "enum_depthwise_true.pt")

    print(f"Input shape: {input.shape}")
    print(f"Input: {input}")
    output = model(input)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
@ -0,0 +1,194 @@
use burn::{
    module::Module,
    nn::conv::{Conv2d, Conv2dConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
    DwsConv(DwsConv<B>),
    Conv(Conv2d<B>),
}

#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
    dconv: Conv2d<B>,
    pconv: Conv2d<B>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv: Conv<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let conv = match record.conv {
            ConvRecord::DwsConv(dws_conv) => {
                let dconv = Conv2dConfig::new([2, 2], [3, 3])
                    .with_groups(2)
                    .init_with(dws_conv.dconv);
                let pconv = Conv2dConfig::new([2, 2], [1, 1])
                    .with_groups(1)
                    .init_with(dws_conv.pconv);
                Conv::DwsConv(DwsConv { dconv, pconv })
            }
            ConvRecord::Conv(conv) => {
                let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]);
                Conv::Conv(conv2d_config.init_with(conv))
            }
        };
        Net { conv }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        match &self.conv {
            Conv::DwsConv(dws_conv) => {
                let x = dws_conv.dconv.forward(x);
                dws_conv.pconv.forward(x)
            }
            Conv::Conv(conv) => conv.forward(x),
        }
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

    use super::*;

    #[test]
    fn depthwise_false() {
        let device = Default::default();
        let load_args =
            LoadArgs::new("tests/enum_module/enum_depthwise_false.pt".into()).with_debug_print();

        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(load_args, &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);
        let input = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
                    [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
                    [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
                    [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
                    [
                        0.804_481_1,
                        0.65517855,
                        0.17679012,
                        0.824_772_3,
                        0.803_550_9,
                    ],
                ],
                [
                    [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
                    [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
                    [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
                    [
                        0.03694397,
                        0.751_675_7,
                        0.148_438_4,
                        0.12274551,
                        0.530_407_2,
                    ],
                    [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
                ],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.35449377, -0.02832414, 0.490_976_1],
                    [0.29709217, 0.332_586_3, 0.30594018],
                    [0.18101373, 0.30932188, 0.30558896],
                ],
                [
                    [-0.17683622, -0.13244139, -0.05608707],
                    [0.23467252, -0.07038684, 0.255_044_1],
                    [-0.241_931_3, -0.20476191, -0.14468731],
                ],
            ]],
            &device,
        );

        output.to_data().assert_approx_eq(&expected.to_data(), 7);
    }

    #[test]
    fn depthwise_true() {
        let device = Default::default();
        let load_args =
            LoadArgs::new("tests/enum_module/enum_depthwise_true.pt".into()).with_debug_print();

        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(load_args, &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
                    [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
                    [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
                    [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
                    [
                        0.804_481_1,
                        0.65517855,
                        0.17679012,
                        0.824_772_3,
                        0.803_550_9,
                    ],
                ],
                [
                    [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
                    [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
                    [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
                    [
                        0.03694397,
                        0.751_675_7,
                        0.148_438_4,
                        0.12274551,
                        0.530_407_2,
                    ],
                    [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
                ],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.77874625, 0.859_017_6, 0.834_283_5],
                    [0.773_056_4, 0.73817325, 0.78292674],
                    [0.710_775_2, 0.747_187_2, 0.733_264_4],
                ],
                [
                    [-0.44891885, -0.49027523, -0.394_170_7],
                    [-0.43836114, -0.33961445, -0.387_311_5],
                    [-0.581_134_3, -0.34197026, -0.535_035_7],
                ],
            ]],
            &device,
        );

        output.to_data().assert_approx_eq(&expected.to_data(), 7);
    }
}
@ -8,6 +8,7 @@ mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod enum_module;
mod group_norm;
mod integer;
mod key_remap;