Add support for loading PyTorch `.pt` (weights/states) files directly to model's record (#1085)

This commit is contained in:
Dilshod Tadjibaev 2024-01-25 09:20:09 -06:00 committed by GitHub
parent 4ca3e31601
commit 0368409eb3
85 changed files with 4494 additions and 49 deletions
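For a quick sense of the new API surface, here is a condensed sketch of the usage exercised by the tests in this commit (the `Net`/`PartialNet` models and the `.pt` fixtures are defined in the test files further down):

use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

type Backend = burn_ndarray::NdArray<f32>;

fn load_examples() {
    let device = Default::default();

    // Load a PyTorch state dict (.pt) straight into a Burn record.
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("tests/conv1d/conv1d.pt".into(), &device)
        .expect("Should decode state successfully");
    let _model = Net::<Backend>::new_with(record);

    // Keys can be remapped while loading, e.g. "conv_blocks.0.*" -> "conv1.*",
    // which lets a partial model load from a full checkpoint.
    let load_args = LoadArgs::new("tests/complex_nested/complex_nested.pt".into())
        .with_key_remap("conv_blocks\\.0\\.(.*)", "conv1.$1");
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args, &device)
        .expect("Should decode state successfully");
    let _partial = PartialNet::<Backend>::new_with(record);
}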

View File

@ -16,6 +16,7 @@ members = [
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-import/pytorch-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
@ -26,6 +27,7 @@ members = [
"burn-train",
"xtask",
"examples/*",
"examples/pytorch-import/model",
"backend-comparison",
]
@ -40,6 +42,9 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
async-trait = "0.1.74"
bytemuck = "1.14"
candle-core = { version = "0.3.2" }
clap = "4.4.11"
console_error_panic_hook = "0.1.7"
const-random = "0.1.17"
csv = "1.3.0"
dashmap = "5.5.3"
@ -56,14 +61,18 @@ libm = "0.2.8"
log = { default-features = false, version = "0.4.20" }
pretty_assertions = "1.4"
proc-macro2 = "1.0.69"
protobuf = "3.3"
protobuf-codegen = "3.3"
quote = "1.0.33"
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.23.0" }
rayon = "1.8.0"
regex = "1.10.2"
reqwest = "0.11.23"
rmp-serde = "1.1.2"
rstest = "0.18.2"
rusqlite = { version = "0.30.0" }
rust-format = { version = "0.3.4" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.34.0"
serde-wasm-bindgen = "0.6.1"
@ -80,8 +89,6 @@ wasm-bindgen = "0.2.88"
wasm-bindgen-futures = "0.4.38"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"
console_error_panic_hook = "0.1.7"
reqwest = "0.11.23"
# WGPU stuff
@ -90,14 +97,15 @@ pollster = "0.3"
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
wgpu = "0.19.0"
#
# The following packages disable the "std" feature for no_std compatibility
#
bincode = { version = "2.0.0-rc.3", features = [
"alloc",
"serde",
], default-features = false }
derive-new = { version = "0.5.9", default-features = false }
#
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.6.0", default-features = false }
half = { version = "2.3.1", features = [
"alloc",

View File

@ -21,8 +21,8 @@ accelerate = ["candle-core/accelerate"]
derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.12.0", default-features = false }
half = { workspace = true }
candle-core = { workspace = true }
candle-core = { version = "0.3.2" }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.12.0", default-features = false, features = [

View File

@ -62,6 +62,9 @@ tch = ["burn-tch"]
candle = ["burn-candle"]
wgpu = ["burn-wgpu"]
# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
@ -105,6 +108,8 @@ bincode = { workspace = true }
half = { workspace = true }
rmp-serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["alloc"] } #Default enables std
thiserror = { workspace = true, optional = true }
regex = { workspace = true, optional = true }
[dev-dependencies]
tempfile = { workspace = true }

View File

@ -11,7 +11,7 @@ use burn_tensor::{
use core::marker::PhantomData;
/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new)]
#[derive(Debug, Clone, Copy, new, Default)]
pub struct ConstantRecord;
impl serde::Serialize for ConstantRecord {

View File

@ -200,7 +200,7 @@ mod tests {
module
.gamma
.as_ref()
.expect("Gamma is None")
.expect("gamma should not be None")
.val()
.to_data()
.assert_approx_eq(&Data::ones([6].into()), 3);
@ -208,7 +208,7 @@ mod tests {
module
.beta
.as_ref()
.expect("beta is None")
.expect("beta should not be None")
.val()
.to_data()
.assert_approx_eq(&Data::zeros([6]), 3);

View File

@ -109,6 +109,11 @@ macro_rules! str2writer {
$file.set_extension(<Self as FileRecorder<B>>::file_extension());
let path = $file.as_path();
// Add parent directories if they don't exist
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).ok();
}
if path.exists() {
log::info!("File exists, replacing");
std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;

View File

@ -17,3 +17,6 @@ mod file;
pub use file::*;
pub use primitive::ParamSerde;
#[cfg(feature = "record-item-custom-serde")]
pub mod serde;

View File

@ -1,20 +1,22 @@
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use burn_tensor::Bool;
use burn_tensor::Int;
use burn_tensor::Tensor;
use serde::Deserialize;
use serde::Serialize;
use alloc::{
string::{String, ToString},
vec,
vec::Vec,
};
use core::{fmt, marker::PhantomData};
use super::tensor::BoolTensorSerde;
use super::tensor::FloatTensorSerde;
use super::tensor::IntTensorSerde;
use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
use super::{PrecisionSettings, Record};
use crate::module::{Param, ParamId};
use burn_tensor::{DataSerialize, Element};
use burn_tensor::{backend::Backend, Bool, DataSerialize, Element, Int, Tensor};
use hashbrown::HashMap;
use serde::{
de::{Error, SeqAccess, Visitor},
ser::SerializeTuple,
Deserialize, Serialize,
};
impl<B> Record<B> for ()
where
@ -63,21 +65,22 @@ where
impl<const N: usize, T, B> Record<B> for [T; N]
where
T: Record<B> + core::fmt::Debug,
T: Record<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
/// The record item is an array of the record items of the elements.
/// The array is wrapped in a struct because serde does not support
/// deserializing arrays of variable size for backward compatibility reasons
/// (the serde APIs were created before const generics); see
/// [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.map(Record::into_item).into_iter().collect()
Array(self.map(Record::into_item))
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.into_iter()
.map(|i| Record::from_item(i, device))
.collect::<Vec<_>>()
.try_into()
.unwrap_or_else(|_| panic!("An array of size {N}"))
item.0.map(|i| Record::from_item(i, device))
}
}
@ -223,3 +226,80 @@ primitive!(i64);
primitive!(i32);
primitive!(i16);
primitive!(i8);
/// A wrapper around an array of size N, so that it can be serialized and deserialized
/// using serde.
///
/// The array is wrapped in a struct because serde does not support
/// deserializing arrays of variable size for backward compatibility reasons
/// (the serde APIs were created before const generics); see
/// [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
pub struct Array<const N: usize, T>([T; N]);
impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut seq = serializer.serialize_tuple(self.0.len())?;
for element in &self.0 {
seq.serialize_element(element)?;
}
seq.end()
}
}
impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
where
T: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ArrayVisitor<T, const N: usize> {
marker: PhantomData<T>,
}
impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
where
T: Deserialize<'de>,
{
type Value = Array<N, T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a fixed size array")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut items = vec![];
for i in 0..N {
let item = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(i, &self))?;
items.push(item);
}
let array: [T; N] = items
.try_into()
.unwrap_or_else(|_| panic!("expected an array of size {N}"));
Ok(Array(array))
}
}
deserializer.deserialize_tuple(
N,
ArrayVisitor {
marker: PhantomData,
},
)
}
}
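Because `Array` serializes as a fixed-length tuple, the element count `N` is known statically and no length prefix is stored. A minimal round-trip sketch (assuming it runs inside this module, where the tuple field is visible, and that `serde_json` is available as a dev-dependency):

let array: Array<3, i32> = Array([1, 2, 3]);
let json = serde_json::to_string(&array).unwrap();
assert_eq!(json, "[1,2,3]");
let back: Array<3, i32> = serde_json::from_str(&json).unwrap();
assert_eq!(back.0, [1, 2, 3]);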

View File

@ -158,6 +158,9 @@ pub enum RecorderError {
/// File not found.
FileNotFound(String),
/// Failed to read file.
DeserializeError(String),
/// Other error.
Unknown(String),
}

View File

@ -0,0 +1,71 @@
use super::data::NestedValue;
/// A trait that defines the adapter for a Burn module.
///
/// This is used to adapt an incoming module to a Burn module.
pub trait BurnModuleAdapter: Sized {
/// Adapts a module.
fn adapt(name: &str, data: NestedValue) -> NestedValue {
match name {
"BatchNorm" => Self::adapt_batch_norm(data),
"Conv1d" => Self::adapt_conv1d(data),
"Conv2d" => Self::adapt_conv2d(data),
"ConvTranspose1d" => Self::adapt_conv_transpose_1d(data),
"ConvTranspose2d" => Self::adapt_conv_transpose_2d(data),
"Embedding" => Self::adapt_embedding(data),
"GroupNorm" => Self::adapt_group_norm(data),
"LayerNorm" => Self::adapt_layer_norm(data),
"Linear" => Self::adapt_linear(data),
_ => data,
}
}
/// Adapts a linear module.
fn adapt_linear(data: NestedValue) -> NestedValue {
data
}
/// Adapts a Convolution 1D module.
fn adapt_conv1d(data: NestedValue) -> NestedValue {
data
}
/// Adapts a Convolution 2D module.
fn adapt_conv2d(data: NestedValue) -> NestedValue {
data
}
/// Adapts convolution transpose 1D module.
fn adapt_conv_transpose_1d(data: NestedValue) -> NestedValue {
data
}
/// Adapts convolution transpose 2D module.
fn adapt_conv_transpose_2d(data: NestedValue) -> NestedValue {
data
}
/// Adapts embedding module.
fn adapt_embedding(data: NestedValue) -> NestedValue {
data
}
/// Adapts group normalization module.
fn adapt_group_norm(data: NestedValue) -> NestedValue {
data
}
/// Adapts layer normalization module.
fn adapt_layer_norm(data: NestedValue) -> NestedValue {
data
}
/// Adapts batch normalization module.
fn adapt_batch_norm(data: NestedValue) -> NestedValue {
data
}
}
/// Default adapter that takes no action.
pub struct DefaultAdapter;
impl BurnModuleAdapter for DefaultAdapter {}
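As a hypothetical illustration (the adapter and field names below are made up), a custom adapter can rewrite a module's nested values before deserialization, while `DefaultAdapter` above is the no-op baseline:

struct RenamingAdapter;

impl BurnModuleAdapter for RenamingAdapter {
    // Rename a source framework's "weights" field to Burn's "weight" for Linear modules.
    fn adapt_linear(data: NestedValue) -> NestedValue {
        match data {
            NestedValue::Map(mut map) => {
                if let Some(w) = map.remove("weights") {
                    map.insert("weight".to_string(), w);
                }
                NestedValue::Map(map)
            }
            other => other,
        }
    }
}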

View File

@ -0,0 +1,261 @@
use std::collections::HashMap;
use super::adapter::BurnModuleAdapter;
use super::de::Deserializer;
use super::error::Error;
use super::ser::Serializer;
use crate::record::{PrecisionSettings, Record};
use crate::tensor::backend::Backend;
use regex::Regex;
use serde::Deserialize;
/// The main data structure used for deserialization.
///
/// It can hold tree-like structures of nested maps and vectors.
#[derive(Debug, Clone)]
pub enum NestedValue {
/// The default variant, which does not hold any data and indicates that
/// the value should be populated with its default.
Default,
/// A boolean value.
Bool(bool),
/// A string value.
String(String),
/// Floating point 32-bit value.
F32(f32),
/// Floating point 64-bit value.
F64(f64),
/// Signed 16-bit integer value.
I16(i16),
/// Signed 32-bit integer value.
I32(i32),
/// Signed 64-bit integer value.
I64(i64),
/// Unsigned 16-bit integer value, used for bf16 and f16 serialization.
U16(u16),
/// Unsigned 64-bit integer value.
U64(u64),
/// A map of nested values (typically used for structs).
Map(HashMap<String, NestedValue>),
/// A vector of nested values (typically used for vectors of structs or numbers).
Vec(Vec<NestedValue>),
}
impl NestedValue {
/// Get the nested value as a map.
pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
match self {
NestedValue::Map(map) => Some(map),
_ => None,
}
}
/// Get the nested value as a boolean.
pub fn as_bool(self) -> Option<bool> {
match self {
NestedValue::Bool(bool) => Some(bool),
_ => None,
}
}
/// Get the nested value as a string.
pub fn as_string(self) -> Option<String> {
match self {
NestedValue::String(string) => Some(string),
_ => None,
}
}
/// Get the nested value as a f32.
pub fn as_f32(self) -> Option<f32> {
match self {
NestedValue::F32(f32) => Some(f32),
_ => None,
}
}
/// Get the nested value as a f64.
pub fn as_f64(self) -> Option<f64> {
match self {
NestedValue::F64(f64) => Some(f64),
_ => None,
}
}
/// Get the nested value as an i16.
pub fn as_i16(self) -> Option<i16> {
match self {
NestedValue::I16(i16) => Some(i16),
_ => None,
}
}
/// Get the nested value as an i32.
pub fn as_i32(self) -> Option<i32> {
match self {
NestedValue::I32(i32) => Some(i32),
_ => None,
}
}
/// Get the nested value as an i64.
pub fn as_i64(self) -> Option<i64> {
match self {
NestedValue::I64(i64) => Some(i64),
_ => None,
}
}
/// Get the nested value as a u16.
pub fn as_u16(self) -> Option<u16> {
match self {
NestedValue::U16(u16) => Some(u16),
_ => None,
}
}
/// Get the nested value as a u64.
pub fn as_u64(self) -> Option<u64> {
match self {
NestedValue::U64(u64) => Some(u64),
_ => None,
}
}
/// Deserialize a nested value into a record type.
pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
where
B: Backend,
T: Record<B>,
PS: PrecisionSettings,
A: BurnModuleAdapter,
{
let deserializer = Deserializer::<A>::new(self, false);
let item = T::Item::deserialize(deserializer)?;
// Convert the deserialized item into a Record instance
Ok(T::from_item::<PS>(item, device))
}
}
/// Remap the tensor locations according to the key remapping.
///
/// # Arguments
///
/// * `tensors` - A map of tensors.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
/// for more information.
///
/// # Returns
///
/// A map of tensors with the remapped keys.
pub fn remap<T>(
mut tensors: HashMap<String, T>,
key_remap: Vec<(Regex, String)>,
) -> HashMap<String, T> {
if key_remap.is_empty() {
return tensors;
}
let mut remapped = HashMap::new();
for (name, tensor) in tensors.drain() {
let mut new_name = name.clone();
for (pattern, replacement) in &key_remap {
if pattern.is_match(&name) {
new_name = pattern.replace_all(&name, replacement.as_str()).to_string();
break;
}
}
remapped.insert(new_name, tensor);
}
remapped
}
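A minimal sketch of `remap` in action, mirroring the key remapping used by the pytorch tests further down (the `usize` values stand in for tensors; `Regex` and `HashMap` are already imported in this file):

fn remap_example() {
    let mut tensors: HashMap<String, usize> = HashMap::new();
    tensors.insert("conv_blocks.0.conv.weight".to_string(), 0);
    tensors.insert("conv_blocks.0.norm.weight".to_string(), 1);

    let key_remap = vec![(
        Regex::new(r"conv_blocks\.0\.(.*)").expect("valid regex"),
        "conv1.$1".to_string(),
    )];

    let remapped = remap(tensors, key_remap);
    assert!(remapped.contains_key("conv1.conv.weight"));
    assert!(remapped.contains_key("conv1.norm.weight"));
}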
/// Helper function to insert a value into a nested map/vector of tensors.
fn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) {
if keys.is_empty() {
*current = value;
return;
}
match current {
NestedValue::Map(map) => {
if !map.contains_key(keys[0]) {
let next = if keys[1..]
.first()
.and_then(|k| k.parse::<usize>().ok())
.is_some()
{
NestedValue::Vec(Vec::new())
} else {
NestedValue::Map(HashMap::new())
};
map.insert(keys[0].to_string(), next);
}
insert_nested_value(map.get_mut(keys[0]).unwrap(), &keys[1..], value);
}
NestedValue::Vec(vec) => {
let index = keys[0].parse::<usize>().unwrap();
if index >= vec.len() {
vec.resize_with(index + 1, || NestedValue::Map(HashMap::new()));
}
insert_nested_value(&mut vec[index], &keys[1..], value);
}
_ => panic!("Invalid structure encountered"),
}
}
/// A trait for encapsulating the serialization logic.
pub trait Serializable {
/// Serializes the object into a `NestedValue` using the provided `Serializer`.
/// This method is generic over the precision settings `PS`.
///
/// # Parameters
/// - `serializer`: The `Serializer` to use for serializing the object.
///
/// # Returns
/// - `Result<NestedValue, Error>`: The result of serialization.
/// Returns a `NestedValue` on success,
/// or an `Error` on failure.
///
/// # Type Parameters
/// - `PS`: The precision settings to use during serialization.
/// This is a generic parameter and can be any type
/// that implements the `PrecisionSettings` trait.
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, Error>
where
PS: PrecisionSettings;
}
/// Convert a map of flattened tensors (with dot-separated keys) into a nested value.
pub fn unflatten<PS, T>(input: HashMap<String, T>) -> Result<NestedValue, Error>
where
PS: PrecisionSettings,
T: Serializable,
{
let mut result = NestedValue::Map(HashMap::new());
for (key, value) in input {
let parts: Vec<&str> = key.split('.').collect();
let st = value.serialize::<PS>(Serializer::new())?;
insert_nested_value(&mut result, &parts, st);
}
Ok(result)
}
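To illustrate `unflatten` (a sketch, assuming it runs inside this module): dot-separated keys become nested maps, and numeric path segments become vector indices.

struct Scalar(f32);

impl Serializable for Scalar {
    fn serialize<PS: PrecisionSettings>(&self, _serializer: Serializer) -> Result<NestedValue, Error> {
        Ok(NestedValue::F32(self.0))
    }
}

fn unflatten_example() -> Result<(), Error> {
    let mut input = HashMap::new();
    input.insert("fc1.weight".to_string(), Scalar(1.0));
    input.insert("fc1.bias".to_string(), Scalar(0.5));
    input.insert("blocks.0.gamma".to_string(), Scalar(2.0));

    // Yields roughly:
    // Map({ "fc1": Map({ "weight": F32(1.0), "bias": F32(0.5) }),
    //       "blocks": Vec([ Map({ "gamma": F32(2.0) }) ]) })
    let nested = unflatten::<crate::record::FullPrecisionSettings, _>(input)?;
    println!("{nested:?}");
    Ok(())
}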

View File

@ -0,0 +1,703 @@
use std::collections::HashMap;
use super::adapter::DefaultAdapter;
use super::data::NestedValue;
use super::{adapter::BurnModuleAdapter, error::Error};
use serde::{
de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
forward_to_deserialize_any,
};
const RECORD_ITEM_SUFFIX: &str = "RecordItem";
/// A deserializer for the nested value data structure.
pub struct Deserializer<A: BurnModuleAdapter> {
// The nested value to deserialize from; it is consumed as parsing proceeds.
value: Option<NestedValue>,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}
impl<A: BurnModuleAdapter> Deserializer<A> {
/// Creates a new deserializer with the given nested value.
///
/// # Arguments
///
/// * `value` - A nested value.
/// * `default_for_missing_fields` - A boolean indicating whether to add missing fields with default value.
pub fn new(value: NestedValue, default_for_missing_fields: bool) -> Self {
Self {
value: Some(value),
default_for_missing_fields,
phantom: std::marker::PhantomData,
}
}
}
impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
type Error = Error;
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_any is not implemented")
}
fn deserialize_struct<V>(
self,
name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let value = match self.value {
Some(value) => {
// Adapt modules
if let Some(name) = name.strip_suffix(RECORD_ITEM_SUFFIX) {
A::adapt(name, value)
} else {
value
}
}
None => {
return Err(de::Error::custom(format!(
"Expected some value but got {:?}",
self.value
)))
}
};
match value {
NestedValue::Map(map) => {
// Add missing fields into the map with default value if needed.
let map = if self.default_for_missing_fields {
let mut map = map;
for field in fields.iter().map(|s| s.to_string()) {
map.entry(field).or_insert(NestedValue::Default);
}
map
} else {
map
};
visitor.visit_map(HashMapAccess::<A>::new(
map,
self.default_for_missing_fields,
))
}
_ => Err(de::Error::custom(format!(
"Expected struct but got {:?}",
value
))),
}
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_string(self.value.unwrap().as_string().unwrap().to_string())
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.value {
Some(NestedValue::Map(map)) => visitor.visit_map(HashMapAccess::<A>::new(
map,
self.default_for_missing_fields,
)),
_ => Err(de::Error::custom(format!(
"Expected map value but got {:?}",
self.value
))),
}
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_bool(self.value.unwrap().as_bool().unwrap())
}
fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_i8 is not implemented")
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i16(self.value.unwrap().as_i16().unwrap().to_owned())
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i32(self.value.unwrap().as_i32().unwrap().to_owned())
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i64(self.value.unwrap().as_i64().unwrap().to_owned())
}
fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_u8 is not implemented")
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u16(self.value.unwrap().as_u16().unwrap().to_owned())
}
fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_u32 is not implemented")
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u64(self.value.unwrap().as_u64().unwrap().to_owned())
}
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_f32(self.value.unwrap().as_f32().unwrap().to_owned())
}
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_f64(self.value.unwrap().as_f64().unwrap().to_owned())
}
fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_char is not implemented")
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_str(self.value.unwrap().as_string().unwrap().as_ref())
}
fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_bytes is not implemented")
}
fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_byte_buf is not implemented")
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if let Some(value) = self.value {
visitor.visit_some(Deserializer::<A>::new(
value,
self.default_for_missing_fields,
))
} else {
visitor.visit_none()
}
}
fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_unit is not implemented")
}
fn deserialize_unit_struct<V>(
self,
_name: &'static str,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_unit_struct is not implemented")
}
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(Deserializer::<A>::new(
self.value.unwrap(),
self.default_for_missing_fields,
))
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if let Some(NestedValue::Vec(vec)) = self.value {
visitor.visit_seq(VecSeqAccess::<A>::new(vec, self.default_for_missing_fields))
} else {
Err(de::Error::custom(format!(
"Expected Vec but got {:?}",
self.value
)))
}
}
fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_tuple is not implemented")
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_tuple_struct is not implemented")
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_enum is not implemented")
}
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_identifier is not implemented")
}
}
/// A sequence access for a vector in the nested value data structure.
struct VecSeqAccess<A: BurnModuleAdapter> {
iter: std::vec::IntoIter<NestedValue>,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}
impl<A: BurnModuleAdapter> VecSeqAccess<A> {
fn new(vec: Vec<NestedValue>, default_for_missing_fields: bool) -> Self {
VecSeqAccess {
iter: vec.into_iter(),
default_for_missing_fields,
phantom: std::marker::PhantomData,
}
}
}
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
{
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
let item = match self.iter.next() {
Some(v) => v,
None => return Ok(None),
};
seed.deserialize(
NestedValueWrapper::<A>::new(item, self.default_for_missing_fields).into_deserializer(),
)
.map(Some)
}
}
/// A map access for a map in the nested value data structure.
struct HashMapAccess<A: BurnModuleAdapter> {
iter: std::collections::hash_map::IntoIter<String, NestedValue>,
next_value: Option<NestedValue>,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}
impl<A: BurnModuleAdapter> HashMapAccess<A> {
fn new(map: HashMap<String, NestedValue>, default_for_missing_fields: bool) -> Self {
HashMapAccess {
iter: map.into_iter(),
next_value: None,
default_for_missing_fields,
phantom: std::marker::PhantomData,
}
}
}
impl<'de, A> MapAccess<'de> for HashMapAccess<A>
where
String: IntoDeserializer<'de, Error>,
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
{
type Error = Error;
fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
match self.iter.next() {
Some((k, v)) => {
// Keep the value for the next call to next_value_seed.
self.next_value = Some(v);
// Deserialize the key.
seed.deserialize(k.into_deserializer()).map(Some)
}
None => Ok(None),
}
}
fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value, Self::Error>
where
T: DeserializeSeed<'de>,
{
match self.next_value.take() {
Some(NestedValue::Default) => seed.deserialize(DefaultDeserializer),
Some(v) => seed.deserialize(
NestedValueWrapper::new(v, self.default_for_missing_fields).into_deserializer(),
),
None => seed.deserialize(DefaultDeserializer),
}
}
}
/// A wrapper for the nested value data structure with a burn module adapter.
struct NestedValueWrapper<A: BurnModuleAdapter> {
value: NestedValue,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}
impl<A: BurnModuleAdapter> NestedValueWrapper<A> {
fn new(value: NestedValue, default_for_missing_fields: bool) -> Self {
Self {
value,
default_for_missing_fields,
phantom: std::marker::PhantomData,
}
}
}
impl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A> {
type Deserializer = Deserializer<A>;
fn into_deserializer(self) -> Self::Deserializer {
Deserializer::<A>::new(self.value, self.default_for_missing_fields)
}
}
/// A default deserializer that always returns the default value.
struct DefaultDeserializer;
impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
type Error = Error;
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!()
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i32(Default::default())
}
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_f32(Default::default())
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i16(Default::default())
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i64(Default::default())
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u16(Default::default())
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u64(Default::default())
}
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_f64(Default::default())
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_bool(Default::default())
}
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_char(Default::default())
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_str(Default::default())
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_i8(Default::default())
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u8(Default::default())
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u32(Default::default())
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_none()
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_seq(DefaultSeqAccess::new(None))
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_string(Default::default())
}
fn deserialize_struct<V>(
self,
_name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let mut map: HashMap<String, NestedValue> = HashMap::new();
for field in fields.iter().map(|s| s.to_string()) {
map.insert(field, NestedValue::Default);
}
visitor.visit_map(HashMapAccess::<DefaultAdapter>::new(map, true))
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_seq(DefaultSeqAccess::new(Some(len)))
}
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_seq(DefaultSeqAccess::new(Some(len)))
}
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_map(DefaultMapAccess::new())
}
forward_to_deserialize_any! {
u128 bytes byte_buf unit unit_struct newtype_struct
enum identifier ignored_any
}
}
/// A default sequence access that yields default values; intended for
/// fixed-length sequences (such as tuples), since it never ends on its own.
pub struct DefaultSeqAccess {
size: Option<usize>,
}
impl Default for DefaultSeqAccess {
fn default() -> Self {
Self::new(None)
}
}
impl DefaultSeqAccess {
/// Creates a new default sequence access with the given size hint.
pub fn new(size: Option<usize>) -> Self {
DefaultSeqAccess { size }
}
}
impl<'de> SeqAccess<'de> for DefaultSeqAccess {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
seed.deserialize(DefaultDeserializer).map(Some)
}
fn size_hint(&self) -> Option<usize> {
self.size
}
}
/// A default map access that always returns None (empty map).
pub struct DefaultMapAccess;
impl Default for DefaultMapAccess {
fn default() -> Self {
Self::new()
}
}
impl DefaultMapAccess {
/// Creates a new default map access.
pub fn new() -> Self {
DefaultMapAccess
}
}
impl<'de> MapAccess<'de> for DefaultMapAccess {
type Error = Error;
fn next_key_seed<T>(&mut self, _seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
// Since this is a default implementation, we'll just return None.
Ok(None)
}
fn next_value_seed<T>(&mut self, _seed: T) -> Result<T::Value, Self::Error>
where
T: DeserializeSeed<'de>,
{
unimplemented!("This should never be called since next_key_seed always returns None")
}
fn size_hint(&self) -> Option<usize> {
// Since this is a default implementation, we'll just return None.
None
}
}
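A minimal end-to-end sketch of the deserializer (`Stats` is a made-up struct; everything else is in scope in this file): when `default_for_missing_fields` is `true`, absent struct fields are filled with their default values.

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Stats {
    count: u64,
    mean: f32,
}

fn deserialize_example() {
    let mut map = HashMap::new();
    map.insert("count".to_string(), NestedValue::U64(3));
    // "mean" is intentionally missing; it is filled with f32::default() (0.0).
    let value = NestedValue::Map(map);

    let de = Deserializer::<DefaultAdapter>::new(value, true);
    let stats = Stats::deserialize(de).expect("should deserialize");
    assert_eq!(stats.count, 3);
    assert_eq!(stats.mean, 0.0);
}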

View File

@ -0,0 +1,40 @@
use crate::record::RecorderError;
/// The error type for Record serde.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Failed to deserialize.
#[error("failed to deserialize")]
Deserialize(#[from] serde::de::value::Error),
/// Failed to serialize.
#[error("failed to serialize")]
Serialize(String),
/// Encountered an invalid state.
#[error("invalid state")]
InvalidState,
/// Other error.
#[error("other error: {0}")]
Other(String),
}
impl serde::de::Error for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Error::Deserialize(serde::de::value::Error::custom(msg.to_string()))
}
}
impl serde::ser::Error for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Error::Serialize(msg.to_string())
}
}
// Implement From trait for Error to RecorderError
impl From<Error> for RecorderError {
fn from(error: Error) -> Self {
RecorderError::DeserializeError(error.to_string())
}
}

View File

@ -0,0 +1,17 @@
//! This module contains the serde implementation for the record module,
//! useful for importing model weights from custom formats such as PyTorch's pt files.
/// The adapter trait that is used to convert the nested value to the module type.
pub mod adapter;
/// The main data structure used for deserialization.
pub mod data;
/// The serializer that is used to convert the record into a nested value.
pub mod ser;
/// The deserializer that is used to convert the nested value to the record.
pub mod de;
/// Error types.
pub mod error;

View File

@ -0,0 +1,375 @@
use std::collections::HashMap;
use super::{
data::NestedValue,
error::{self, Error},
};
use serde::{
ser::{self, SerializeSeq, SerializeStruct, Serializer as SerializerTrait},
Serialize,
};
/// Simple struct serializer that converts a struct into NestedValues.
///
/// NOTE: This is primarily used to serialize Param structs into NestedValues, not for
/// the full serialization of modules (although it could serve that purpose if all
/// primitive types were implemented).
pub struct Serializer {
// The state of the serialization process
state: Option<NestedValue>,
}
impl Serializer {
/// Creates a new serializer.
pub fn new() -> Self {
Serializer { state: None }
}
}
impl Default for Serializer {
fn default() -> Self {
Self::new()
}
}
impl SerializerTrait for Serializer {
type Ok = NestedValue;
type Error = Error;
type SerializeSeq = Self;
type SerializeTuple = ser::Impossible<NestedValue, Self::Error>;
type SerializeTupleStruct = ser::Impossible<NestedValue, Self::Error>;
type SerializeTupleVariant = ser::Impossible<NestedValue, Self::Error>;
type SerializeMap = ser::Impossible<NestedValue, Self::Error>;
type SerializeStruct = Self;
type SerializeStructVariant = ser::Impossible<NestedValue, Self::Error>;
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Self::Error> {
Ok(self)
}
fn serialize_newtype_struct<T: ?Sized>(
self,
_name: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
{
value.serialize(self)
}
fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
Ok(self)
}
fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::I32(v))
}
fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::String(v.to_string()))
}
fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::I16(v))
}
fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::I64(v))
}
fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::U16(v))
}
fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::U64(v))
}
fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::F32(v))
}
fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::F64(v))
}
// The following methods are not implemented because they are not needed for the
// serialization of Param structs.
fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::Default)
}
fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
{
value.serialize(self)
}
fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
) -> Result<Self::Ok, Self::Error> {
unimplemented!()
}
fn serialize_newtype_variant<T: ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
{
unimplemented!()
}
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
unimplemented!()
}
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Self::Error> {
unimplemented!()
}
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant, Self::Error> {
unimplemented!()
}
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
unimplemented!()
}
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
unimplemented!()
}
}
// Implementing the SerializeStruct trait for Serializer
impl SerializeStruct for Serializer {
type Ok = NestedValue;
type Error = Error;
fn serialize_field<T: ?Sized>(
&mut self,
key: &'static str,
value: &T,
) -> Result<(), Self::Error>
where
T: Serialize,
{
let serialized_value = value.serialize(Serializer::new())?;
match self.state {
Some(NestedValue::Map(ref mut map)) => {
map.insert(key.to_string(), serialized_value); // Inserting into the state
}
Some(_) => {
panic!("Invalid state encountered");
}
None => {
let mut map = HashMap::new();
map.insert(key.to_string(), serialized_value); // Inserting into the state
self.state = Some(NestedValue::Map(map));
}
}
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
if self.state.is_none() {
// If the state is empty, return an empty map
Ok(NestedValue::Map(HashMap::new()))
} else {
self.state.ok_or(error::Error::InvalidState)
}
}
}
impl SerializeSeq for Serializer {
type Ok = NestedValue;
type Error = Error;
fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
{
let serialized_value = value.serialize(Serializer::new())?;
match self.state {
Some(NestedValue::Vec(ref mut vec)) => {
vec.push(serialized_value); // Inserting into the state
}
Some(_) => {
panic!("Invalid state encountered");
}
None => {
self.state = Some(NestedValue::Vec(vec![serialized_value]));
}
}
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
if self.state.is_none() {
// If the state is empty, return an empty vector
Ok(NestedValue::Vec(Vec::new()))
} else {
self.state.ok_or(error::Error::InvalidState)
}
}
}
#[cfg(test)]
mod tests {
use crate::{
module::{Param, ParamId},
record::{FullPrecisionSettings, Record},
tensor::Tensor,
};
use serde::Deserialize;
use super::*;
#[derive(Serialize, Deserialize, Debug, Clone)]
struct MyStruct1 {
a: MyStruct3,
b: MyStruct2,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct MyStruct2 {
a: i32,
b: Option<i32>,
c: String,
d: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct MyStruct3 {
x: String,
y: String,
}
#[test]
fn test_serialize() {
let my_struct = MyStruct1 {
a: MyStruct3 {
x: "Hello".to_owned(),
y: "World".to_owned(),
},
b: MyStruct2 {
a: 1,
b: None,
c: "Hello".to_owned(),
d: Some("World".to_owned()),
},
};
let serialized = my_struct
.serialize(Serializer::new())
.expect("Should serialize item successfully");
let serialized_str = format!("{:?}", serialized);
// Compare the lengths of expected and actual serialized strings because
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(
serialized_str.len(),
concat!(
r#"Map({"b": Map({"a": I32(1), "c": String("Hello"), "b": Default, "d": String("World")}),"#,
r#" "a": Map({"x": String("Hello"), "y": String("World")})})"#
).len()
);
}
#[test]
fn test_param_serde() {
type Backend = burn_ndarray::NdArray<f32>;
let device = Default::default();
let tensor: Tensor<Backend, 2> = Tensor::ones([2, 2], &device);
let param = Param::new(ParamId::new(), tensor);
let param_item = param.into_item::<FullPrecisionSettings>();
let serialized = param_item
.serialize(Serializer::new())
.expect("Should serialize item successfully");
let serialized_str = format!("{:?}", serialized);
// Compare the lengths of expected and actual serialized strings because
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(
serialized_str.len(),
concat!(
r#"Map({"id": String("ca893b0b-92cf-4856-a1c2-558191dbb930"), "#,
r#""param": Map({"shape": Vec([U64(2), U64(2)]), "#,
r#""value": Vec([F32(1.0), F32(1.0), F32(1.0), F32(1.0)])})})"#
)
.len()
);
}
}

View File

@ -13,5 +13,5 @@ pub struct TestWithBackendRecord<B: Backend> {
// It compiles
#[derive(Record)]
pub struct TestWithoutBackendRecord {
tensor: usize,
_tensor: usize,
}

View File

@ -28,6 +28,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
/// Field to be serialized.
pub #name: <#ty as burn::record::Record<B>>::Item<S>,
});
bounds.extend(quote! {
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
});

View File

@ -1,7 +1,7 @@
[package]
authors = [
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
]
description = "Library for importing datamodels into the Burn framework"
edition.workspace = true
@ -9,35 +9,40 @@ license.workspace = true
name = "burn-import"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-import"
version.workspace = true
default-run = "onnx2burn"
[features]
default = ["onnx"]
default = ["onnx", "pytorch"]
onnx = []
pytorch = ["burn/record-item-custom-serde", "thiserror"]
[dependencies]
burn = { path = "../burn", version = "0.12.0" }
burn-ndarray = { path = "../burn-ndarray", version = "0.12.0" }
burn = { path = "../burn", version = "0.12.0", features = ["ndarray"] }
bytemuck = { workspace = true }
candle-core = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true }
log = { workspace = true }
proc-macro2 = { workspace = true }
protobuf = { version = "3.3", features = ["with-bytes"] }
protobuf = { workspace = true, features = ["with-bytes"] }
quote = { workspace = true }
rust-format = { version = "0.3", features = ["token_stream", "post_process"] }
serde = { workspace = true }
regex = { workspace = true }
rust-format = { workspace = true, features = ["token_stream", "post_process"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true, features = ["std"] }
strum = { workspace = true }
strum_macros = { workspace = true }
syn = { workspace = true, features = ["parsing"] }
tracing-subscriber = { workspace = true }
thiserror = { workspace = true, optional = true }
tracing-core = { workspace = true }
tracing-subscriber = { workspace = true }
[build-dependencies]
protobuf-codegen = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
rstest = { workspace = true }

BIN
burn-import/data/mnist.pt Normal file

Binary file not shown.

View File

@ -0,0 +1,17 @@
[package]
name = "pytorch-tests"
version.workspace = true
edition.workspace = true
license.workspace = true
[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["pytorch"] }
cfg-if = "1.0.0"
[build-dependencies]
burn-import = { path = "../", features = ["pytorch"] }

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.BatchNorm2d(5)
def forward(self, x):
x = self.norm1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
# Condition batch norm (each forward will affect the running stats)
x1 = torch.ones(1, 5, 2, 2) - 0.5
_ = model(x1)
model.eval() # Set to eval mode to freeze running stats
# Save the model after the first forward
torch.save(model.state_dict(), "batch_norm2d.pt")
x2 = torch.ones(1, 5, 2, 2) - 0.3
print("Input shape: {}", x2.shape)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,60 @@
use burn::{
module::Module,
nn::{BatchNorm, BatchNormConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: BatchNorm<B, 2>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let norm1 = BatchNormConfig::new(4).init_with(record.norm1);
Self { norm1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
#[test]
fn batch_norm2d() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/batch_norm/batch_norm2d.pt".into(), &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::ones([1, 5, 2, 2], &device) - 0.3;
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
]],
&device,
);
output.to_data().assert_approx_eq(&expected.to_data(), 5);
}
}

Binary file not shown.

View File

@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.tensor([True, False, True])
self.register_buffer("buffer", buffer, persistent=True)
def forward(self, x):
x = self.buffer
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "boolean.pt")
input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,58 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Bool, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 1, Bool>>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
Self {
buffer: record.buffer,
}
}
/// Forward pass of the model.
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
self.buffer.val()
}
}
#[cfg(test)]
mod tests {
use burn::{
record::{FullPrecisionSettings, Recorder},
tensor::Data,
};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
type Backend = burn_ndarray::NdArray<f32>;
#[test]
#[ignore = "It appears loading boolean tensors are not supported yet"]
// Error skipping: Msg("unsupported storage type BoolStorage")
fn boolean() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/boolean/boolean.pt".into(), &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 2>::ones([3, 3], &device);
let output = model.forward(input);
let expected =
Tensor::<Backend, 1, Bool>::from_bool(Data::from([true, false, true]), &device);
assert_eq!(output.to_data(), expected.to_data());
}
}

Binary file not shown.

View File

@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.ones(3, 3)
self.register_buffer("buffer", buffer, persistent=True)
def forward(self, x):
x = self.buffer + x
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "buffer.pt")
input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,51 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 2>>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
Self {
buffer: record.buffer,
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
self.buffer.val() + x
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
#[test]
fn buffer() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/buffer/buffer.pt".into(), &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 2>::ones([3, 3], &device);
let output = model.forward(input);
let expected = Tensor::<Backend, 2>::ones([3, 3], &device) * 2.0;
output.to_data().assert_approx_eq(&expected.to_data(), 3);
}
}

View File

@ -0,0 +1,69 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
self.norm = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_blocks = nn.Sequential(
ConvBlock(2, 4, (3, 2)),
ConvBlock(4, 6, (3, 2)),
)
self.norm1 = nn.BatchNorm2d(6)
self.fc1 = nn.Linear(120, 12)
self.fc2 = nn.Linear(12, 10)
def forward(self, x):
x = self.conv_blocks(x)
x = self.norm1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, dim=1)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(2)
model = Net().to(torch.device("cpu"))
# Condition the model (batch norm requires a forward pass to compute the mean and variance)
x1 = torch.ones(1, 2, 9, 6) - 0.1
x2 = torch.ones(1, 2, 9, 6) - 0.3
output = model(x1)
output = model(x2)
model.eval() # set to eval mode
torch.save(model.state_dict(), "complex_nested.pt")
# feed test data
x = torch.ones(1, 2, 9, 6) - 0.5
output = model(x)
print("Input shape: {}", x.shape)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,229 @@
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn::{
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig,
},
tensor::{
activation::{log_softmax, relu},
backend::Backend,
Tensor,
},
};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Conv2d<B>,
norm: BatchNorm<B, 2>,
}
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv_blocks: Vec<ConvBlock<B>>,
norm1: BatchNorm<B, 2>,
fc1: Linear<B>,
fc2: Linear<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let mut record = record;
let record_block_0 = record.conv_blocks.remove(0);
let record_block_1 = record.conv_blocks.remove(0);
let conv_blocks = vec![
ConvBlock {
conv: Conv2dConfig::new([2, 4], [3, 2]).init_with(record_block_0.conv),
norm: BatchNormConfig::new(2).init_with(record_block_0.norm),
},
ConvBlock {
conv: Conv2dConfig::new([4, 6], [3, 2]).init_with(record_block_1.conv),
norm: BatchNormConfig::new(4).init_with(record_block_1.norm),
},
];
let norm1 = BatchNormConfig::new(6).init_with(record.norm1);
let fc1 = LinearConfig::new(120, 12).init_with(record.fc1);
let fc2 = LinearConfig::new(12, 10).init_with(record.fc2);
Self {
conv_blocks,
norm1,
fc1,
fc2,
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
let x = self.conv_blocks[0].forward(x);
let x = self.conv_blocks[1].forward(x);
let x = self.norm1.forward(x);
let x = x.reshape([0, -1]);
let x = self.fc1.forward(x);
let x = relu(x);
let x = self.fc2.forward(x);
log_softmax(x, 1)
}
}
impl<B: Backend> ConvBlock<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv.forward(x);
self.norm.forward(x)
}
}
/// Partial model to test loading of partial records.
#[derive(Module, Debug)]
pub struct PartialNet<B: Backend> {
conv1: ConvBlock<B>,
}
impl<B: Backend> PartialNet<B> {
/// Create a new model from the given record.
pub fn new_with(record: PartialNetRecord<B>) -> Self {
let conv1 = ConvBlock {
conv: Conv2dConfig::new([2, 4], [3, 2]).init_with(record.conv1.conv),
norm: BatchNormConfig::new(2).init_with(record.conv1.norm),
};
Self { conv1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.conv1.forward(x)
}
}
/// Model with extra fields to test loading of records (e.g. from a different model).
#[derive(Module, Debug)]
pub struct PartialWithExtraNet<B: Backend> {
conv1: ConvBlock<B>,
extra_field: bool, // This field is not present in the pytorch model
}
impl<B: Backend> PartialWithExtraNet<B> {
/// Create a new model from the given record.
pub fn new_with(record: PartialWithExtraNetRecord<B>) -> Self {
let conv1 = ConvBlock {
conv: Conv2dConfig::new([2, 4], [3, 2]).init_with(record.conv1.conv),
norm: BatchNormConfig::new(2).init_with(record.conv1.norm),
};
Self {
conv1,
extra_field: true,
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.conv1.forward(x)
}
}
type TestBackend = burn_ndarray::NdArray<f32>;
fn model_test(record: NetRecord<TestBackend>, precision: usize) {
let device = Default::default();
let model = Net::<TestBackend>::new_with(record);
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
let output = model.forward(input);
let expected = Tensor::<TestBackend, 2>::from_data(
[[
-2.306_613,
-2.058_945_4,
-2.298_372_7,
-2.358_294,
-2.296_395_5,
-2.416_090_5,
-2.107_669,
-2.428_420_8,
-2.526_469,
-2.319_918_6,
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn full_record() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/complex_nested/complex_nested.pt".into(), &device)
.expect("Should decode state successfully");
model_test(record, 8);
}
#[test]
fn half_record() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/complex_nested/complex_nested.pt".into(), &device)
.expect("Should decode state successfully");
model_test(record, 4);
}
#[test]
fn partial_model_loading() {
// Load the full model but rename "conv_blocks.0.*" to "conv1.*"
let load_args = LoadArgs::new("tests/complex_nested/complex_nested.pt".into())
.with_key_remap("conv_blocks\\.0\\.(.*)", "conv1.$1");
let device = Default::default();
// Load the partial record from the full model
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)
.expect("Should decode state successfully");
let model = PartialNet::<TestBackend>::new_with(record);
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
let output = model.forward(input);
// get the sum of all elements in the output tensor for quick check
let sum = output.sum();
assert_eq!(4.871538, sum.into_scalar());
}
#[test]
fn extra_field_model_loading() {
// Load the full model but rename "conv_blocks.0.*" to "conv1.*"
let load_args = LoadArgs::new("tests/complex_nested/complex_nested.pt".into())
.with_key_remap("conv_blocks\\.0\\.(.*)", "conv1.$1");
let device = Default::default();
// Load the partial record from the full model
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)
.expect("Should decode state successfully");
let model = PartialWithExtraNet::<TestBackend>::new_with(record);
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
let output = model.forward(input);
// get the sum of all elements in the output tensor for quick check
let sum = output.sum();
assert_eq!(4.871538, sum.into_scalar());
assert!(model.extra_field);
}

Binary file not shown.

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv1d(2, 2, 2)
self.conv2 = nn.Conv1d(2, 2, 2, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv1d.pt")
input = torch.rand(1, 2, 6)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,95 @@
use burn::{
module::Module,
nn::conv::{Conv1d, Conv1dConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv1d<B>,
conv2: Conv1d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let conv1 = Conv1dConfig::new(2, 2, 2).init_with(record.conv1);
let conv2 = Conv1dConfig::new(2, 2, 2)
.with_bias(false)
.init_with(record.conv2);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn conv1d(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 3>::from_data(
[[
[
0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965,
],
[
0.33977574,
0.523_940_8,
0.798_063_9,
0.77176833,
0.01122457,
0.80996025,
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 3>::from_data(
[[
[0.02987457, 0.03134188, 0.04234261, -0.02437721],
[-0.03788019, -0.02972012, -0.00806090, -0.01981254],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn conv1d_full_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/conv1d/conv1d.pt".into(), &device)
.expect("Should decode state successfully");
conv1d(record, 7);
}
#[test]
fn conv1d_half_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/conv1d/conv1d.pt".into(), &device)
.expect("Should decode state successfully");
conv1d(record, 4);
}
}

Binary file not shown.

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))
self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv2d.pt")
input = torch.rand(1, 2, 5, 5)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,131 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init_with(record.conv1);
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
.with_bias(false)
.init_with(record.conv2);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn conv2d(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[
[
0.024_595_8,
0.25883394,
0.93905586,
0.416_715_5,
0.713_979_7,
],
[0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
[0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
[0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
[
0.99802476,
0.900_794_2,
0.476_588_2,
0.16625845,
0.804_481_1,
],
],
[
[
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
0.943_447_5,
],
[0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
[0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
[0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
[
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
0.414_796_4,
],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[-0.02502128, 0.00250649, 0.04841233],
[0.04589614, -0.00296854, 0.01991477],
[0.02920526, 0.059_497_3, 0.04326791],
],
[
[-0.04825336, 0.080_190_9, -0.02375088],
[0.02885434, 0.09638263, -0.07460806],
[0.02004079, 0.06244051, 0.035_887_1],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn conv2d_full_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/conv2d/conv2d.pt".into(), &device)
.expect("Should decode state successfully");
conv2d(record, 7);
}
#[test]
fn conv2d_half_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/conv2d/conv2d.pt".into(), &device)
.expect("Should decode state successfully");
conv2d(record, 4);
}
}

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.ConvTranspose1d(2, 2, 2)
self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv_transpose1d.pt")
input = torch.rand(1, 2, 2)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,81 @@
use burn::{
module::Module,
nn::conv::{ConvTranspose1d, ConvTranspose1dConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: ConvTranspose1d<B>,
conv2: ConvTranspose1d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init_with(record.conv1);
let conv2 = ConvTranspose1dConfig::new([2, 2], 2)
.with_bias(false)
.init_with(record.conv2);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn conv_transpose1d(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 3>::from_data(
[[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 3>::from_data(
[[
[0.02935525, 0.01119324, -0.01356167, -0.00682688],
[0.01644749, -0.01429807, 0.00083987, 0.00279229],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn conv_transpose1d_full() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/conv_transpose1d/conv_transpose1d.pt".into(), &device)
.expect("Should decode state successfully");
conv_transpose1d(record, 8);
}
#[test]
fn conv_transpose1d_half() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/conv_transpose1d/conv_transpose1d.pt".into(), &device)
.expect("Should decode state successfully");
conv_transpose1d(record, 4);
}
}

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2))
self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv_transpose2d.pt")
input = torch.rand(1, 2, 2, 2)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,94 @@
use burn::{
module::Module,
nn::conv::{ConvTranspose2d, ConvTranspose2dConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: ConvTranspose2d<B>,
conv2: ConvTranspose2d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init_with(record.conv1);
let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init_with(record.conv2);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn conv_transpose2d(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
[[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[0.04547675, 0.01879685, -0.01636661, 0.00310803],
[0.02090115, 0.01192738, -0.048_240_2, 0.02252235],
[0.03249975, -0.00460748, 0.05003899, 0.04029131],
[0.02185687, -0.10226749, -0.06508022, -0.01267705],
],
[
[0.00277598, -0.00513832, -0.059_048_3, 0.00567626],
[-0.03149522, -0.195_757_4, 0.03474613, 0.01997269],
[-0.10096474, 0.00679589, 0.041_919_7, -0.02464108],
[-0.03174751, 0.02963913, -0.02703723, -0.01860938],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn conv_transpose2d_full() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/conv_transpose2d/conv_transpose2d.pt".into(), &device)
.expect("Should decode state successfully");
conv_transpose2d(record, 7);
}
#[test]
fn conv_transpose2d_half() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/conv_transpose2d/conv_transpose2d.pt".into(), &device)
.expect("Should decode state successfully");
conv_transpose2d(record, 4);
}
}

Binary file not shown.

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.embed = nn.Embedding(10, 3)
def forward(self, x):
x = self.embed(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "embedding.pt")
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,84 @@
use burn::{
module::Module,
nn::{Embedding, EmbeddingConfig},
tensor::{backend::Backend, Int, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
embed: Embedding<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let embed = EmbeddingConfig::new(10, 3).init_with(record.embed);
Self { embed }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
self.embed.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn embedding(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 2, Int>::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device);
let output = model.forward(input);
let expected = Tensor::<Backend, 3>::from_data(
[
[
[-1.609_484_9, -0.10016718, -0.609_188_9],
[-0.97977227, -1.609_096_3, -0.712_144_6],
[-0.22227049, 1.687_113_4, -0.32062083],
[-0.29934573, 1.879_345_7, -0.07213178],
],
[
[-0.22227049, 1.687_113_4, -0.32062083],
[0.303_722, -0.777_314_3, -0.25145486],
[-0.97977227, -1.609_096_3, -0.712_144_6],
[-0.02878714, 2.357_111, -1.037_338_7],
],
],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn embedding_full_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/embedding/embedding.pt".into(), &device)
.expect("Should decode state successfully");
embedding(record, 3);
}
#[test]
fn embedding_half_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/embedding/embedding.pt".into(), &device)
.expect("Should decode state successfully");
embedding(record, 3);
}
}

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.GroupNorm(2, 6)
def forward(self, x):
x = self.norm1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "group_norm.pt")
x2 = torch.rand(1, 6, 2, 2)
print("Input shape: {}", x2.shape)
print("Input: {}", x2)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,88 @@
use burn::{
module::Module,
nn::{GroupNorm, GroupNormConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: GroupNorm<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let norm1 = GroupNormConfig::new(2, 6).init_with(record.norm1);
Self { norm1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn group_norm(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
[[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
[[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]],
[[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]],
[[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]],
[[0.31379688, 0.19801933], [0.41619217, 0.28432965]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]],
[[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]],
[[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]],
[[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]],
[[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]],
[[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn group_norm_full() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/group_norm/group_norm.pt".into(), &device)
.expect("Should decode state successfully");
group_norm(record, 3);
}
#[test]
fn group_norm_half() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/group_norm/group_norm.pt".into(), &device)
.expect("Should decode state successfully");
group_norm(record, 3);
}
}

View File

@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.tensor([1, 2, 3])
self.register_buffer("buffer", buffer, persistent=True)
def forward(self, x):
x = self.buffer
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "integer.pt")
input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,69 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Int, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 1, Int>>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
Self {
buffer: record.buffer,
}
}
/// Forward pass of the model.
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Int> {
self.buffer.val()
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::{
record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder},
tensor::Data,
};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn integer(record: NetRecord<Backend>, _precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 2>::ones([3, 3], &device);
let output = model.forward(input);
let expected = Tensor::<Backend, 1, Int>::from_data(Data::from([1, 2, 3]), &device);
assert_eq!(output.to_data(), expected.to_data());
}
#[test]
fn integer_full_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/integer/integer.pt".into(), &device)
.expect("Should decode state successfully");
integer(record, 0);
}
#[test]
fn integer_half_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/integer/integer.pt".into(), &device)
.expect("Should decode state successfully");
integer(record, 0);
}
}

View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvModule(nn.Module):
def __init__(self):
super(ConvModule, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))
self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv = ConvModule()
def forward(self, x):
x = self.conv(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "key_remap.pt")
input = torch.rand(1, 2, 5, 5)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,116 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init_with(record.conv1);
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
.with_bias(false)
.init_with(record.conv2);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
use super::*;
#[test]
fn key_remap() {
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
.with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(load_args, &device)
.expect("Should decode state successfully");
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[
[
0.024_595_8,
0.25883394,
0.93905586,
0.416_715_5,
0.713_979_7,
],
[0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
[0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
[0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
[
0.99802476,
0.900_794_2,
0.476_588_2,
0.16625845,
0.804_481_1,
],
],
[
[
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
0.943_447_5,
],
[0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
[0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
[0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
[
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
0.414_796_4,
],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[-0.02502128, 0.00250649, 0.04841233],
[0.04589614, -0.00296854, 0.01991477],
[0.02920526, 0.059_497_3, 0.04326791],
],
[
[-0.04825336, 0.080_190_9, -0.02375088],
[0.02885434, 0.09638263, -0.07460806],
[0.02004079, 0.06244051, 0.035_887_1],
],
]],
&device,
);
output.to_data().assert_approx_eq(&expected.to_data(), 7);
}
}

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.LayerNorm(2)
def forward(self, x):
x = self.norm1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "layer_norm.pt")
x2 = torch.rand(1, 2, 2, 2)
print("Input shape: {}", x2.shape)
print("Input: {}", x2)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,79 @@
use burn::{
module::Module,
nn::{LayerNorm, LayerNormConfig},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: LayerNorm<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let norm1 = LayerNormConfig::new(2).init_with(record.norm1);
Self { norm1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn layer_norm(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
[[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]],
[[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn layer_norm_full() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/layer_norm/layer_norm.pt".into(), &device)
.expect("Should decode state successfully");
layer_norm(record, 3);
}
#[test]
fn layer_norm_half() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/layer_norm/layer_norm.pt".into(), &device)
.expect("Should decode state successfully");
layer_norm(record, 3);
}
}

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(2, 3)
self.fc2 = nn.Linear(3, 4, bias=False)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2
x = self.fc2(x)
return x
class ModelWithBias(nn.Module):
def __init__(self):
super(ModelWithBias, self).__init__()
self.fc1 = nn.Linear(2, 3)
def forward(self, x):
x = self.fc1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
model_with_bias = ModelWithBias().to(torch.device("cpu"))
torch.save(model.state_dict(), "linear.pt")
torch.save(model_with_bias.state_dict(), "linear_with_bias.pt")
input = torch.rand(1, 2, 2, 2)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
print("Model with bias")
output = model_with_bias(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,149 @@
use burn::{
module::Module,
nn::{Linear, LinearConfig, ReLU},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
relu: ReLU,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);
let fc2 = LinearConfig::new(3, 4)
.with_bias(false)
.init_with(record.fc2);
let relu = ReLU::default();
Self { fc1, fc2, relu }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.fc1.forward(x);
let x = self.relu.forward(x);
self.fc2.forward(x)
}
}
#[derive(Module, Debug)]
struct NetWithBias<B: Backend> {
fc1: Linear<B>,
}
impl<B: Backend> NetWithBias<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetWithBiasRecord<B>) -> Self {
let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);
Self { fc1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.fc1.forward(x)
}
}
#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
use super::*;
fn linear_test(record: NetRecord<Backend>, precision: usize) {
let device = Default::default();
let model = Net::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
[[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[0.09778349, -0.13756673, 0.04962806, 0.08856435],
[0.03163241, -0.02848549, 0.01437942, 0.11905234],
],
[
[0.07628226, -0.10757702, 0.03656857, 0.03824598],
[0.05443089, -0.06904714, 0.02744314, 0.09997337],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq(&expected.to_data(), precision);
}
#[test]
fn linear_full_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/linear/linear.pt".into(), &device)
.expect("Should decode state successfully");
linear_test(record, 7);
}
#[test]
fn linear_half_precision() {
let device = Default::default();
let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
.load("tests/linear/linear.pt".into(), &device)
.expect("Should decode state successfully");
linear_test(record, 4);
}
#[test]
fn linear_with_bias() {
let device = Default::default();
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("tests/linear/linear_with_bias.pt".into(), &device)
.expect("Should decode state successfully");
let model = NetWithBias::<Backend>::new_with(record);
let input = Tensor::<Backend, 4>::from_data(
[[
[[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
[[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<Backend, 4>::from_data(
[[
[
[-0.00432095, -1.107_101_2, 0.870_691_4],
[0.024_595_5, -0.954_462_9, 0.48518157],
],
[
[0.34315687, -0.757_384_2, 0.548_288],
[-0.06608963, -1.072_072_7, 0.645_800_5],
],
]],
&device,
);
output.to_data().assert_approx_eq(&expected.to_data(), 6);
}
}

View File

@ -0,0 +1,20 @@
cfg_if::cfg_if! {
if #[cfg(not(target_os = "windows"))] {
// The crate is not supported on Windows because of Candle's bug when
// loading `.pt` files on Windows (see https://github.com/huggingface/candle/issues/1454).
mod batch_norm;
mod boolean;
mod buffer;
mod complex_nested;
mod conv1d;
mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod group_norm;
mod integer;
mod key_remap;
mod layer_norm;
mod linear;
}
}

View File

@ -1,5 +1,7 @@
#[cfg(feature = "onnx")]
use burn_import::onnx::{ModelGen, RecordType};
#[cfg(feature = "onnx")]
/// Takes an ONNX file and generates a model from it
fn main() {
let onnx_file = std::env::args().nth(1).expect("No input file provided");
@ -15,3 +17,8 @@ fn main() {
.out_dir(output_dir.as_str())
.run_from_cli();
}
#[cfg(not(feature = "onnx"))]
fn main() {
println!("Compiled without Pytorch feature.");
}

View File

@ -50,7 +50,7 @@ pub struct BurnGraph<PS: PrecisionSettings> {
}
// The backend used for recording.
type Backend = burn_ndarray::NdArray;
type Backend = burn::backend::ndarray::NdArray;
impl<PS: PrecisionSettings> BurnGraph<PS> {
/// Register a new operation node into the graph.
@ -400,7 +400,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
pub fn from_embedded(device: &B::Device) -> Self {
let record = BinBytesRecorder::<#precision_ty>::default()
.load(EMBEDDED_STATES.to_vec(), device)
.expect("Failed to decode state");
.expect("Should decode state successfully");
Self::new_with(record)
}

View File

@ -6,8 +6,8 @@ use super::{
max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
use burn::record::PrecisionSettings;
use burn_ndarray::NdArray;
use proc_macro2::TokenStream;
use serde::Serialize;

View File

@ -9,16 +9,25 @@
//! aligns the imported model with Burn's model and converts tensor data into a format compatible with
//! Burn.
#[cfg(any(feature = "pytorch", feature = "onnx"))]
#[macro_use]
extern crate derive_new;
// Compiled only when the `pytorch` or `onnx` feature is enabled.
#[cfg(any(feature = "pytorch", feature = "onnx"))]
mod logger;
/// The onnx module.
#[cfg(feature = "onnx")]
pub mod onnx;
/// The module for generating the burn code.
#[cfg(feature = "onnx")]
pub mod burn;
/// The PyTorch module for recorder.
#[cfg(feature = "pytorch")]
pub mod pytorch;
mod formatter;
mod logger;
pub use formatter::*;

View File

@ -1,3 +1,5 @@
#![allow(dead_code)]
use std::error::Error;
use tracing_core::LevelFilter;

View File

@ -0,0 +1,102 @@
use burn::{
module::Param,
record::{PrecisionSettings, Record},
tensor::{backend::Backend, Tensor},
};
use burn::record::serde::{
adapter::{BurnModuleAdapter, DefaultAdapter},
data::NestedValue,
ser::Serializer,
};
use serde::Serialize;
/// A PyTorch adapter for the Burn module used during deserialization.
///
/// Not all Burn modules correspond directly to PyTorch modules. Therefore,
/// we need to adapt the PyTorch data to Burn's module format; we implement
/// adapters only for the modules that differ.
pub struct PyTorchAdapter<PS: PrecisionSettings, B: Backend> {
_precision_settings: std::marker::PhantomData<(PS, B)>,
}
impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS, B> {
fn adapt_linear(data: NestedValue) -> NestedValue {
// Get the current module in the form of map.
let mut map = data.as_map().expect("Failed to get map from NestedValue");
// Get/remove the weight parameter.
let weight = map
.remove("weight")
.expect("Failed to find 'weight' key in map");
// Convert the weight parameter to a tensor (use the default device, since it's a quick operation).
let weight: Param<Tensor<B, 2>> = weight
.try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
.expect("Failed to deserialize weight");
// Transpose the weight tensor.
let weight_transposed = Param::from(weight.val().transpose());
// Insert the transposed weight tensor back into the map.
map.insert(
"weight".to_owned(),
serialize::<PS, _, 2>(weight_transposed),
);
// Return the modified map.
NestedValue::Map(map)
}
fn adapt_group_norm(data: NestedValue) -> NestedValue {
rename_weight_bias(data)
}
fn adapt_batch_norm(data: NestedValue) -> NestedValue {
rename_weight_bias(data)
}
fn adapt_layer_norm(data: NestedValue) -> NestedValue {
rename_weight_bias(data)
}
}
/// Helper function to serialize a param tensor.
fn serialize<PS, B, const D: usize>(val: Param<Tensor<B, D>>) -> NestedValue
where
B: Backend,
PS: PrecisionSettings,
{
let serializer = Serializer::new();
val.into_item::<PS>()
.serialize(serializer)
.expect("Failed to serialize the item")
}
/// Helper function to rename the weight and bias parameters to gamma and beta.
///
/// This is needed because PyTorch uses different names for the normalization parameters
/// than Burn: Burn uses `gamma` and `beta`, while PyTorch uses `weight` and `bias`.
fn rename_weight_bias(data: NestedValue) -> NestedValue {
// Get the current module in the form of map.
let mut map = data.as_map().expect("Failed to get map from NestedValue");
// Rename the weight parameter to gamma.
let weight = map
.remove("weight")
.expect("Failed to find 'weight' key in map");
map.insert("gamma".to_owned(), weight);
// Rename the bias parameter to beta.
let bias = map
.remove("bias")
.expect("Failed to find 'bias' key in map");
map.insert("beta".to_owned(), bias);
// Return the modified map.
NestedValue::Map(map)
}

View File

@ -0,0 +1,21 @@
use burn::record::{serde::error, RecorderError};
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Serde error: {0}")]
Serde(#[from] error::Error),
#[error("Candle pickle error: {0}")]
CandlePickle(#[from] candle_core::Error),
// Add other kinds of errors as needed
#[error("other error: {0}")]
Other(String),
}
// Implement From trait for Error to RecorderError
impl From<Error> for RecorderError {
fn from(error: Error) -> Self {
RecorderError::DeserializeError(error.to_string())
}
}

View File

@ -0,0 +1,6 @@
mod adapter;
mod error;
mod reader;
mod recorder;
pub use recorder::{LoadArgs, PyTorchFileRecorder};

View File

@ -0,0 +1,129 @@
use core::ops::Deref;
use std::collections::HashMap;
use std::path::Path;
use super::{adapter::PyTorchAdapter, error::Error};
use burn::{
module::ParamId,
record::{ParamSerde, PrecisionSettings},
tensor::{DataSerialize, Element, ElementConversion},
};
use burn::{
record::serde::{
data::{remap, unflatten, NestedValue, Serializable},
de::Deserializer,
error,
ser::Serializer,
},
tensor::backend::Backend,
};
use candle_core::{pickle, WithDType};
use half::{bf16, f16};
use regex::Regex;
use serde::{de::DeserializeOwned, Serialize};
/// Deserializes a PyTorch file.
///
/// # Arguments
///
/// * `path` - A string slice that holds the path of the file to read.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
pub fn from_file<PS, D, B>(path: &Path, key_remap: Vec<(Regex, String)>) -> Result<D, Error>
where
D: DeserializeOwned,
PS: PrecisionSettings,
B: Backend,
{
// Read the pickle file and return a map of Candle tensors
let tensors: HashMap<String, CandleTensor> = pickle::read_all(path)?
.into_iter()
.map(|(key, tensor)| (key, CandleTensor(tensor)))
.collect();
// Remap the keys (replace the keys in the map with the new keys)
let tensors = remap(tensors, key_remap);
// Convert the map of Candle tensors to a nested value data structure
let nested_value = unflatten::<PS, _>(tensors)?;
// Create a deserializer with PyTorch adapter and nested value
let deserializer = Deserializer::<PyTorchAdapter<PS, B>>::new(nested_value, true);
// Deserialize the nested value into a record type
let value = D::deserialize(deserializer)?;
Ok(value)
}
/// Serializes a candle tensor.
///
/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `DataSerialize` struct.
///
/// Values are serialized as `FloatElem` or `IntElem` depending on the precision settings.
impl Serializable for CandleTensor {
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, error::Error>
where
PS: PrecisionSettings,
{
let shape = self.shape().clone().into_dims();
let flatten = CandleTensor(self.flatten_all().expect("Failed to flatten the tensor"));
let param_id = ParamId::new().into_string();
match self.dtype() {
candle_core::DType::U8 => {
serialize_data::<u8, PS::IntElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::U32 => {
serialize_data::<u32, PS::IntElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::I64 => {
serialize_data::<i64, PS::IntElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::BF16 => {
serialize_data::<bf16, PS::FloatElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::F16 => {
serialize_data::<f16, PS::FloatElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::F32 => {
serialize_data::<f32, PS::FloatElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::F64 => {
serialize_data::<f64, PS::FloatElem>(flatten, shape, param_id, serializer)
}
}
}
}
/// Helper function to serialize a candle tensor data.
fn serialize_data<T, E>(
tensor: CandleTensor,
shape: Vec<usize>,
param_id: String,
serializer: Serializer,
) -> Result<NestedValue, error::Error>
where
E: Element + Serialize,
T: WithDType + ElementConversion,
{
let data: Vec<E> = tensor
.to_vec1::<T>()
.map_err(|err| error::Error::Other(format!("Candle to vec1 error: {err}")))?
.into_iter()
.map(ElementConversion::elem)
.collect();
ParamSerde::new(param_id, DataSerialize::new(data, shape)).serialize(serializer)
}
/// Newtype struct for Candle tensors; we need it to implement the `Serializable` trait for them.
struct CandleTensor(candle_core::Tensor);
impl Deref for CandleTensor {
type Target = candle_core::Tensor;
fn deref(&self) -> &Self::Target {
&self.0
}
}

View File

@ -0,0 +1,136 @@
use core::marker::PhantomData;
use std::path::PathBuf;
use burn::{
record::{PrecisionSettings, Record, Recorder, RecorderError},
tensor::backend::Backend,
};
use regex::Regex;
use serde::{de::DeserializeOwned, Serialize};
use super::reader::from_file;
/// A recorder that loads PyTorch files (`.pt`) into Burn modules.
///
/// [`LoadArgs`] can be used to specify the file path and to remap parameter keys.
/// See [LoadArgs](struct.LoadArgs.html) for more information.
///
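/// A minimal usage sketch (`MyModelRecord`, the backend `B`, and the file path are
/// placeholders for your own model, not part of this crate):
///
/// ```text
/// use burn::record::{FullPrecisionSettings, Recorder};
/// use burn_import::pytorch::PyTorchFileRecorder;
///
/// let device = Default::default();
/// let record: MyModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
///     .load("path/to/weights.pt".into(), &device)
///     .expect("Should decode state successfully");
/// ```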
#[derive(new, Debug, Default, Clone)]
pub struct PyTorchFileRecorder<PS: PrecisionSettings> {
_settings: PhantomData<PS>,
}
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS> {
type Settings = PS;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = LoadArgs;
fn save_item<I: Serialize>(
&self,
_item: I,
_file: Self::RecordArgs,
) -> Result<(), RecorderError> {
unimplemented!("save_item not implemented for PyTorchFileRecorder")
}
fn load_item<I: DeserializeOwned>(&self, _file: Self::LoadArgs) -> Result<I, RecorderError> {
unimplemented!("load_item not implemented for PyTorchFileRecorder")
}
fn load<R: Record<B>>(
&self,
args: Self::LoadArgs,
device: &B::Device,
) -> Result<R, RecorderError> {
let item = from_file::<PS, R::Item<Self::Settings>, B>(&args.file, args.key_remap)?;
Ok(R::from_item(item, device))
}
}
/// Arguments for loading a PyTorch file.
///
/// # Fields
///
/// * `file` - The path to the file to load.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
/// for more information.
///
/// # Notes
///
/// Use [Netron](https://github.com/lutzroeder/netron) to inspect the keys of the PyTorch file (.pt extension).
///
///
/// # Examples
///
/// ```text
/// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
/// use burn::record::FullPrecisionSettings;
/// use burn::record::Recorder;
///
/// let args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
///    .with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
///
/// let device = Default::default();
/// let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
///    .load(args, &device)
/// .expect("Should decode state successfully");
/// ```
#[derive(Debug, Clone)]
pub struct LoadArgs {
/// The path to the file to load.
pub file: PathBuf,
/// A list of key remappings.
pub key_remap: Vec<(Regex, String)>,
}
impl LoadArgs {
/// Create a new `LoadArgs` instance.
///
/// # Arguments
///
/// * `file` - The path to the file to load.
pub fn new(file: PathBuf) -> Self {
Self {
file,
key_remap: Vec::new(),
}
}
/// Set key remapping.
///
/// # Arguments
///
/// * `pattern` - The Regex pattern to be replaced.
/// * `replacement` - The pattern to replace with.
///
/// See [Regex](https://docs.rs/regex/1.5.4/regex/#syntax) for the pattern syntax and
/// [Replacement](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) for the
/// replacement syntax.
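///
/// Note that the pattern is anchored internally (it is compiled as `^pattern$`),
/// so it must match the whole key.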
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
let regex = Regex::new(&format!("^{}$", pattern)).unwrap();
self.key_remap.push((regex, replacement.into()));
self
}
}
impl From<PathBuf> for LoadArgs {
fn from(val: PathBuf) -> Self {
LoadArgs::new(val)
}
}
impl From<String> for LoadArgs {
fn from(val: String) -> Self {
LoadArgs::new(val.into())
}
}
impl From<&str> for LoadArgs {
fn from(val: &str) -> Self {
LoadArgs::new(val.into())
}
}

View File

@ -53,6 +53,9 @@ candle = ["burn-core/candle"]
# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
# Records
record-item-custom-serde = ["burn-core/record-item-custom-serde"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **

View File

@ -0,0 +1,24 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
edition = "2021"
license = "MIT OR Apache-2.0"
name = "pytorch-import"
publish = false
version = "0.12.0"
[dependencies]
burn = { path = "../../burn", features = [
"ndarray",
"dataset",
"sqlite-bundled",
] }
model = { path = "./model" }
[build-dependencies]
model = { path = "./model" }
burn = { path = "../../burn", features = ["ndarray"] }
burn-import = { path = "../../burn-import", features = [
"pytorch",
], default-features = false }

View File

@ -0,0 +1,29 @@
# Import PyTorch Weights
This crate provides a simple example of importing PyTorch-generated weights into Burn.
The `.pt` file is converted into a Burn-consumable file (MessagePack format) using `burn-import`.
The conversion is done in the `build.rs` file.
The model is separated into a sub-crate because `build.rs` needs it for the conversion, and a
build script cannot import modules from its own crate. A sketch of the conversion is shown below.
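The following mirrors what `build.rs` does, with the Windows guard omitted (the
`ModelRecord` type and the paths come from this example's `model` sub-crate):
```rust
use burn::backend::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

type B = NdArray<f32>;

fn main() {
    let device = Default::default();
    // Load the PyTorch weights into a Burn model record...
    let record: model::ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("pytorch/mnist.pt".into(), &device)
        .expect("Failed to decode state");
    // ...and save the record in Burn's named MessagePack format.
    NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .record(record, "model/mnist".into())
        .expect("Failed to save model record");
}
```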
## Usage
```bash
cargo run -- 15
```
Output:
```bash
Finished dev [unoptimized + debuginfo] target(s) in 0.13s
Running `burn/target/debug/pytorch-import 15`
Image index: 15
Success!
Predicted: 5
Actual: 5
See the image online, click the link below:
https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/15/image/image.jpg
```

View File

@ -0,0 +1,41 @@
//! This build script does the following:
//! 1. Loads PyTorch weights into a model record.
//! 2. Saves the model record to a file using the `NamedMpkFileRecorder`.
use std::path::Path;
use burn::backend::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
// Backend type used to load the model record.
type B = NdArray<f32>;
fn main() {
if cfg!(target_os = "windows") {
println!(
"{}",
"cargo:warning=The crate is not supported on Windows because of ".to_owned()
+ "Candle's pt bug on Windows "
+ "(see https://github.com/huggingface/candle/issues/1454)."
);
std::process::exit(0);
}
let device = Default::default();
// Load PyTorch weights into a model record.
let record: model::ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("pytorch/mnist.pt".into(), &device)
.expect("Failed to decode state");
// Save the model record to a file.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
// Save into the OUT_DIR directory so that the model can be loaded by the model crate at runtime.
let out_dir = std::env::var("OUT_DIR").unwrap();
let file_path = Path::new(&out_dir).join("model/mnist");
recorder
.record(record, file_path)
.expect("Failed to save model record");
}

View File

@ -0,0 +1,11 @@
[package]
name = "model"
version = "0.1.0"
edition = "2021"
[dependencies]
burn = { path = "../../../burn" }
burn-import = { path = "../../../burn-import", features = [
"pytorch",
], default-features = false }

View File

@ -0,0 +1,81 @@
use std::env;
use std::path::Path;
use burn::nn::conv::Conv2d;
use burn::nn::conv::Conv2dConfig;
use burn::nn::BatchNorm;
use burn::nn::BatchNormConfig;
use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::record::FullPrecisionSettings;
use burn::record::NamedMpkFileRecorder;
use burn::record::Recorder;
use burn::tensor::activation::log_softmax;
use burn::tensor::activation::relu;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
conv3: Conv2d<B>,
norm1: BatchNorm<B, 2>,
fc1: Linear<B>,
fc2: Linear<B>,
norm2: BatchNorm<B, 0>,
phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
let out_dir = env::var_os("OUT_DIR").unwrap();
let file_path = Path::new(&out_dir).join("model/mnist");
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
.load(file_path, &B::Device::default())
.expect("Failed to decode state");
Self::new_with(record)
}
}
impl<B: Backend> Model<B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
let conv1 = Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1);
let conv2 = Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2);
let conv3 = Conv2dConfig::new([16, 24], [3, 3]).init_with(record.conv3);
let norm1 = BatchNormConfig::new(24).init_with(record.norm1);
let fc1 = LinearConfig::new(11616, 32).init_with(record.fc1);
let fc2 = LinearConfig::new(32, 10).init_with(record.fc2);
let norm2 = BatchNormConfig::new(10).init_with(record.norm2);
Self {
conv1,
conv2,
conv3,
norm1,
fc1,
fc2,
norm2,
phantom: core::marker::PhantomData,
}
}
pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> {
let conv1_out1 = self.conv1.forward(input1);
let relu1_out1 = relu(conv1_out1);
let conv2_out1 = self.conv2.forward(relu1_out1);
let relu2_out1 = relu(conv2_out1);
let conv3_out1 = self.conv3.forward(relu2_out1);
let relu3_out1 = relu(conv3_out1);
let norm1_out1 = self.norm1.forward(relu3_out1);
let flatten1_out1 = norm1_out1.flatten(1, 3);
let fc1_out1 = self.fc1.forward(flatten1_out1);
let relu4_out1 = relu(fc1_out1);
let fc2_out1 = self.fc2.forward(relu4_out1);
let norm2_out1 = self.norm2.forward(fc2_out1);
log_softmax(norm2_out1, 1)
}
}

Binary file not shown.

View File

@ -0,0 +1,163 @@
#!/usr/bin/env python3
# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py
# under the following license: BSD-3-Clause license
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 8, 3)
self.conv2 = nn.Conv2d(8, 16, 3)
self.conv3 = nn.Conv2d(16, 24, 3)
self.norm1 = nn.BatchNorm2d(24)
self.dropout1 = nn.Dropout(0.3)
self.fc1 = nn.Linear(24 * 22 * 22, 32)
self.fc2 = nn.Linear(32, 10)
self.norm2 = nn.BatchNorm1d(10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = self.conv3(x)
x = F.relu(x)
x = self.norm1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout1(x)
x = self.fc2(x)
x = self.norm2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=8, metavar='N',
help='number of epochs to train (default: 8)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
help='disables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=True,
help='For Saving the current Model')
parser.add_argument('--export-onnx', action='store_true', default=False,
help='For Saving the current Model in ONNX format')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
torch.manual_seed(args.seed)
if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('/tmp/mnist-data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('/tmp/mnist-data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model.state_dict(), "mnist.pt")
if args.export_onnx:
dummy_input = torch.randn(1, 1, 28, 28, device=device)
torch.onnx.export(model, dummy_input, "mnist.onnx",
verbose=True, opset_version=16)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,74 @@
use std::env::args;
use std::path::Path;
use burn::backend::ndarray::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::tensor::Tensor;
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::Dataset;
use model::Model;
const IMAGE_INX: usize = 42; // <- Change this to test a different image
// Build output directory that contains the converted model weight file path
const OUT_DIR: &str = concat!(env!("OUT_DIR"), "/model/mnist");
fn main() {
// Get image index argument (first) from command line
let image_index = if let Some(image_index) = args().nth(1) {
println!("Image index: {}", image_index);
image_index
.parse::<usize>()
.expect("Failed to parse image index")
} else {
println!("No image index provided; Using default image index: {IMAGE_INX}");
IMAGE_INX
};
assert!(image_index < 10000, "Image index must be less than 10000");
type Backend = NdArray<f32>;
let device = Default::default();
// Load the model record from converted PyTorch file by the build script
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
.load(Path::new(OUT_DIR).into(), &device)
.expect("Failed to decode state");
// Create a new model and load the state
let model: Model<Backend> = Model::new_with(record);
// Load the MNIST dataset and get an item
let dataset = MNISTDataset::test();
let item = dataset.get(image_index).unwrap();
// Create a tensor from the image data
let image_data = item.image.iter().copied().flatten().collect::<Vec<f32>>();
let mut input: Tensor<Backend, 4> =
Tensor::from_floats(image_data.as_slice(), &device).reshape([1, 1, 28, 28]);
// Normalize the input
input = ((input / 255) - 0.1307) / 0.3081;
// Run the model on the input
let output = model.forward(input);
// Get the index of the maximum value
let arg_max = output.argmax(1).into_scalar() as usize;
// Check if the index matches the label
assert!(arg_max == item.label);
println!("Success!");
println!("Predicted: {}", arg_max);
println!("Actual: {}", item.label);
// Print the image URL if the image index is less than 100 (the online dataset only has 100 images)
if image_index < 100 {
println!("See the image online, click the link below:");
println!("https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{image_index}/image/image.jpg");
}
}

View File

@ -246,9 +246,18 @@ fn no_std_checks() {
// Test burn-core with tch and wgpu backend
fn burn_core_std() {
// Run cargo test --features test-tch
group!("Test: burn-core (tch)");
cargo_test(["-p", "burn-core", "--features", "test-tch"].into());
// Run cargo test --features test-tch, record-item-custom-serde
group!("Test: burn-core (tch) and record-item-custom-serde");
cargo_test(
[
"-p",
"burn-core",
"--features",
"test-tch",
"record-item-custom-serde",
]
.into(),
);
endgroup!();
// Run cargo test --features test-wgpu