mirror of https://github.com/tracel-ai/burn.git
Add support for loading PyTorch `.pt` (weights/states) files directly to model's record (#1085)
Parent: 4ca3e31601
Commit: 0368409eb3
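For context, the new recorder is exercised like this in the tests added at the end of this diff (the `Net`/`NetRecord` types come from the batch-norm test below; the file path here is a placeholder):

```rust
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

type Backend = burn_ndarray::NdArray<f32>;

fn load() -> Net<Backend> {
    let device = Default::default();
    // Load the PyTorch .pt state dict directly into the model's record type.
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("model.pt".into(), &device) // placeholder path
        .expect("Should decode state successfully");
    Net::<Backend>::new_with(record)
}
```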
Cargo.toml (20 changes)
```
@@ -16,6 +16,7 @@ members = [
    "burn-derive",
    "burn-import",
    "burn-import/onnx-tests",
    "burn-import/pytorch-tests",
    "burn-ndarray",
    "burn-no-std-tests",
    "burn-tch",
@@ -26,6 +27,7 @@ members = [
    "burn-train",
    "xtask",
    "examples/*",
    "examples/pytorch-import/model",
    "backend-comparison",
]

@@ -40,6 +42,9 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
async-trait = "0.1.74"
bytemuck = "1.14"
candle-core = { version = "0.3.2" }
clap = "4.4.11"
console_error_panic_hook = "0.1.7"
const-random = "0.1.17"
csv = "1.3.0"
dashmap = "5.5.3"
@@ -56,14 +61,18 @@ libm = "0.2.8"
log = { default-features = false, version = "0.4.20" }
pretty_assertions = "1.4"
proc-macro2 = "1.0.69"
protobuf = "3.3"
protobuf-codegen = "3.3"
quote = "1.0.33"
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.23.0" }
rayon = "1.8.0"
regex = "1.10.2"
reqwest = "0.11.23"
rmp-serde = "1.1.2"
rstest = "0.18.2"
rusqlite = { version = "0.30.0" }
rust-format = { version = "0.3.4" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.34.0"
serde-wasm-bindgen = "0.6.1"
@@ -80,8 +89,6 @@ wasm-bindgen = "0.2.88"
wasm-bindgen-futures = "0.4.38"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"
console_error_panic_hook = "0.1.7"
reqwest = "0.11.23"

# WGPU stuff
@@ -90,14 +97,15 @@ pollster = "0.3"
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
wgpu = "0.19.0"

#
# The following packages disable the "std" feature for no_std compatibility
#
bincode = { version = "2.0.0-rc.3", features = [
    "alloc",
    "serde",
], default-features = false }
derive-new = { version = "0.5.9", default-features = false }

#
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.6.0", default-features = false }

half = { version = "2.3.1", features = [
    "alloc",
```
```
@@ -21,8 +21,8 @@ accelerate = ["candle-core/accelerate"]
derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.12.0", default-features = false }
half = { workspace = true }
candle-core = { workspace = true }

candle-core = { version = "0.3.2" }

[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.12.0", default-features = false, features = [
```
```
@@ -62,6 +62,9 @@ tch = ["burn-tch"]
candle = ["burn-candle"]
wgpu = ["burn-wgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
@@ -105,6 +108,8 @@ bincode = { workspace = true }
half = { workspace = true }
rmp-serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["alloc"] } #Default enables std
thiserror = { workspace = true, optional = true }
regex = { workspace = true, optional = true }

[dev-dependencies]
tempfile = { workspace = true }
```
```
@@ -11,7 +11,7 @@ use burn_tensor::{
use core::marker::PhantomData;

/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new)]
#[derive(Debug, Clone, Copy, new, Default)]
pub struct ConstantRecord;

impl serde::Serialize for ConstantRecord {
```
```
@@ -200,7 +200,7 @@ mod tests {
        module
            .gamma
            .as_ref()
            .expect("Gamma is None")
            .expect("gamma should not be None")
            .val()
            .to_data()
            .assert_approx_eq(&Data::ones([6].into()), 3);
@@ -208,7 +208,7 @@ mod tests {
        module
            .beta
            .as_ref()
            .expect("beta is None")
            .expect("beta should not be None")
            .val()
            .to_data()
            .assert_approx_eq(&Data::zeros([6]), 3);
```
```
@@ -109,6 +109,11 @@ macro_rules! str2writer {
        $file.set_extension(<Self as FileRecorder<B>>::file_extension());
        let path = $file.as_path();

        // Add parent directories if they don't exist
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).ok();
        }

        if path.exists() {
            log::info!("File exists, replacing");
            std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
```
```
@@ -17,3 +17,6 @@ mod file;
pub use file::*;

pub use primitive::ParamSerde;

#[cfg(feature = "record-item-custom-serde")]
pub mod serde;
```
```
@@ -1,20 +1,22 @@
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use burn_tensor::Bool;
use burn_tensor::Int;
use burn_tensor::Tensor;
use serde::Deserialize;
use serde::Serialize;
use alloc::{
    string::{String, ToString},
    vec,
    vec::Vec,
};
use core::{fmt, marker::PhantomData};

use super::tensor::BoolTensorSerde;
use super::tensor::FloatTensorSerde;
use super::tensor::IntTensorSerde;
use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
use super::{PrecisionSettings, Record};
use crate::module::{Param, ParamId};
use burn_tensor::{DataSerialize, Element};

use burn_tensor::{backend::Backend, Bool, DataSerialize, Element, Int, Tensor};

use hashbrown::HashMap;
use serde::{
    de::{Error, SeqAccess, Visitor},
    ser::SerializeTuple,
    Deserialize, Serialize,
};

impl<B> Record<B> for ()
where
@@ -63,21 +65,22 @@ where

impl<const N: usize, T, B> Record<B> for [T; N]
where
    T: Record<B> + core::fmt::Debug,
    T: Record<B>,
    B: Backend,
{
    type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
    /// The record item is an array of the record item of the elements.
    /// The reason why we wrap the array in a struct is because serde does not support
    /// deserializing arrays of variable size,
    /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937)
    /// for backward compatibility reasons. Serde APIs were created before const generics.
    type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        self.map(Record::into_item).into_iter().collect()
        Array(self.map(Record::into_item))
    }

    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        item.into_iter()
            .map(|i| Record::from_item(i, device))
            .collect::<Vec<_>>()
            .try_into()
            .unwrap_or_else(|_| panic!("An arrar of size {N}"))
        item.0.map(|i| Record::from_item(i, device))
    }
}
```
```
@@ -223,3 +226,80 @@ primitive!(i64);
primitive!(i32);
primitive!(i16);
primitive!(i8);

/// A wrapper around an array of size N, so that it can be serialized and deserialized
/// using serde.
///
/// The reason why we wrap the array in a struct is because serde does not support
/// deserializing arrays of variable size,
/// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937)
/// for backward compatibility reasons. Serde APIs were created before const generics.
pub struct Array<const N: usize, T>([T; N]);

impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut seq = serializer.serialize_tuple(self.0.len())?;
        for element in &self.0 {
            seq.serialize_element(element)?;
        }
        seq.end()
    }
}

impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
where
    T: Deserialize<'de>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ArrayVisitor<T, const N: usize> {
            marker: PhantomData<T>,
        }

        impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
        where
            T: Deserialize<'de>,
        {
            type Value = Array<N, T>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a fixed size array")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut items = vec![];

                for i in 0..N {
                    let item = seq
                        .next_element()?
                        .ok_or_else(|| Error::invalid_length(i, &self))?;
                    items.push(item);
                }

                let array: [T; N] = items
                    .into_iter()
                    .collect::<Vec<_>>()
                    .try_into()
                    .map_err(|_| "An array of size {N}")
                    .unwrap();

                Ok(Array(array))
            }
        }

        deserializer.deserialize_tuple(
            N,
            ArrayVisitor {
                marker: PhantomData,
            },
        )
    }
}
```
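A minimal, self-contained sketch of the same tuple trick outside Burn (the `FixedArray` name is made up; only `serde` and `serde_json` are assumed): serializing a const-generic array as a tuple of known length sidesteps the derive limitation the comment above describes.

```rust
use serde::ser::SerializeTuple;
use serde::{Serialize, Serializer};

// Wrapper so we can hand-write Serialize for an arbitrary-size array.
struct FixedArray<const N: usize, T>([T; N]);

impl<T: Serialize, const N: usize> Serialize for FixedArray<N, T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // A tuple of length N carries the fixed size through serde.
        let mut seq = serializer.serialize_tuple(N)?;
        for element in &self.0 {
            seq.serialize_element(element)?;
        }
        seq.end()
    }
}

fn main() {
    let json = serde_json::to_string(&FixedArray([1.0f32, 2.0, 3.0])).unwrap();
    assert_eq!(json, "[1.0,2.0,3.0]"); // tuples serialize as JSON arrays
}
```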
```
@@ -158,6 +158,9 @@ pub enum RecorderError {
    /// File not found.
    FileNotFound(String),

    /// Failed to read file.
    DeserializeError(String),

    /// Other error.
    Unknown(String),
}
```
```
@@ -0,0 +1,71 @@
use super::data::NestedValue;

/// A trait that defines the adapter for a Burn module.
///
/// This is used to adapt an incoming module to a Burn module.
pub trait BurnModuleAdapter: Sized {
    /// Adapts a module.
    fn adapt(name: &str, data: NestedValue) -> NestedValue {
        match name {
            "BatchNorm" => Self::adapt_batch_norm(data),
            "Conv1d" => Self::adapt_conv1d(data),
            "Conv2d" => Self::adapt_conv2d(data),
            "ConvTranspose1d" => Self::adapt_conv_transpose_1d(data),
            "ConvTranspose2d" => Self::adapt_conv_transpose_2d(data),
            "Embedding" => Self::adapt_embedding(data),
            "GroupNorm" => Self::adapt_group_norm(data),
            "LayerNorm" => Self::adapt_layer_norm(data),
            "Linear" => Self::adapt_linear(data),
            _ => data,
        }
    }

    /// Adapts a linear module.
    fn adapt_linear(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts a Convolution 1D module.
    fn adapt_conv1d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts a Convolution 2D module.
    fn adapt_conv2d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts convolution transpose 1D module.
    fn adapt_conv_transpose_1d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts convolution transpose 2D module.
    fn adapt_conv_transpose_2d(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts embedding module.
    fn adapt_embedding(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts group normalization module.
    fn adapt_group_norm(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts layer normalization module.
    fn adapt_layer_norm(data: NestedValue) -> NestedValue {
        data
    }

    /// Adapts batch normalization module.
    fn adapt_batch_norm(data: NestedValue) -> NestedValue {
        data
    }
}

/// Default adapter that takes no action.
pub struct DefaultAdapter;
impl BurnModuleAdapter for DefaultAdapter {}
```
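The hooks are meant to be overridden per source framework. A hypothetical adapter sketch (the module paths, the adapter name, and the key rename are assumptions for illustration, not part of this commit):

```rust
use burn::record::serde::adapter::BurnModuleAdapter; // path assumed
use burn::record::serde::data::NestedValue; // path assumed

struct MyFrameworkAdapter;

impl BurnModuleAdapter for MyFrameworkAdapter {
    // Hypothetical hook: rename an incoming `weight` entry before the
    // record item is deserialized. Real adapters could also rewrite data.
    fn adapt_linear(data: NestedValue) -> NestedValue {
        match data {
            NestedValue::Map(mut map) => {
                if let Some(w) = map.remove("weight") {
                    map.insert("weights".to_string(), w);
                }
                NestedValue::Map(map)
            }
            other => other, // anything that is not a map passes through
        }
    }
}
```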
```
@@ -0,0 +1,261 @@
use std::collections::HashMap;

use super::adapter::BurnModuleAdapter;
use super::de::Deserializer;
use super::error::Error;
use super::ser::Serializer;
use crate::record::{PrecisionSettings, Record};
use crate::tensor::backend::Backend;

use regex::Regex;
use serde::Deserialize;

/// The main data structure used for deserialization.
///
/// It can hold tree-like structures of nested maps and vectors.
#[derive(Debug, Clone)]
pub enum NestedValue {
    /// The default value, which actually does not hold any value and it is used to indicate that
    /// the value should be populated with the default value.
    Default,

    /// A boolean value.
    Bool(bool),

    /// A string value.
    String(String),

    /// Floating point 32-bit value.
    F32(f32),

    /// Floating point 64-bit value.
    F64(f64),

    /// Signed 16-bit integer value.
    I16(i16),

    /// Signed 32-bit integer value.
    I32(i32),

    /// Signed 64-bit integer value.
    I64(i64),

    /// Unsigned 16-bit integer value used for bf16 and f16 serialization
    U16(u16),

    /// Unsigned 64-bit integer value.
    U64(u64),

    /// A map of nested values (typically used for structs)
    Map(HashMap<String, NestedValue>),

    /// A vector of nested values (typically used for vector of structs or numbers)
    Vec(Vec<NestedValue>),
}

impl NestedValue {
    /// Get the nested value as a map.
    pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
        match self {
            NestedValue::Map(map) => Some(map),
            _ => None,
        }
    }

    /// Get the nested value as a boolean.
    pub fn as_bool(self) -> Option<bool> {
        match self {
            NestedValue::Bool(bool) => Some(bool),
            _ => None,
        }
    }

    /// Get the nested value as a string.
    pub fn as_string(self) -> Option<String> {
        match self {
            NestedValue::String(string) => Some(string),
            _ => None,
        }
    }

    /// Get the nested value as a f32.
    pub fn as_f32(self) -> Option<f32> {
        match self {
            NestedValue::F32(f32) => Some(f32),
            _ => None,
        }
    }

    /// Get the nested value as a f64.
    pub fn as_f64(self) -> Option<f64> {
        match self {
            NestedValue::F64(f64) => Some(f64),
            _ => None,
        }
    }

    /// Get the nested value as an i16.
    pub fn as_i16(self) -> Option<i16> {
        match self {
            NestedValue::I16(i16) => Some(i16),
            _ => None,
        }
    }

    /// Get the nested value as an i32.
    pub fn as_i32(self) -> Option<i32> {
        match self {
            NestedValue::I32(i32) => Some(i32),
            _ => None,
        }
    }

    /// Get the nested value as an i64.
    pub fn as_i64(self) -> Option<i64> {
        match self {
            NestedValue::I64(i64) => Some(i64),
            _ => None,
        }
    }

    /// Get the nested value as a u16.
    pub fn as_u16(self) -> Option<u16> {
        match self {
            NestedValue::U16(u16) => Some(u16),
            _ => None,
        }
    }

    /// Get the nested value as a u64.
    pub fn as_u64(self) -> Option<u64> {
        match self {
            NestedValue::U64(u64) => Some(u64),
            _ => None,
        }
    }

    /// Deserialize a nested value into a record type.
    pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
    where
        B: Backend,
        T: Record<B>,
        PS: PrecisionSettings,
        A: BurnModuleAdapter,
    {
        let deserializer = Deserializer::<A>::new(self, false);

        let item = T::Item::deserialize(deserializer)?;

        // Convert the deserialized item into a Record instance
        Ok(T::from_item::<PS>(item, device))
    }
}

/// Remap the tensor locations according to the key remapping.
///
/// # Arguments
///
/// * `tensors` - A map of tensors.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
/// for more information.
///
/// # Returns
///
/// A map of tensors with the remapped keys.
pub fn remap<T>(
    mut tensors: HashMap<String, T>,
    key_remap: Vec<(Regex, String)>,
) -> HashMap<String, T> {
    if key_remap.is_empty() {
        return tensors;
    }

    let mut remapped = HashMap::new();

    for (name, tensor) in tensors.drain() {
        let mut new_name = name.clone();
        for (pattern, replacement) in &key_remap {
            if pattern.is_match(&name) {
                new_name = pattern.replace_all(&name, replacement.as_str()).to_string();
                break;
            }
        }
        remapped.insert(new_name, tensor);
    }

    remapped
}

/// Helper function to insert a value into a nested map/vector of tensors.
fn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) {
    if keys.is_empty() {
        *current = value;
        return;
    }

    match current {
        NestedValue::Map(map) => {
            if !map.contains_key(keys[0]) {
                let next = if keys[1..]
                    .first()
                    .and_then(|k| k.parse::<usize>().ok())
                    .is_some()
                {
                    NestedValue::Vec(Vec::new())
                } else {
                    NestedValue::Map(HashMap::new())
                };
                map.insert(keys[0].to_string(), next);
            }
            insert_nested_value(map.get_mut(keys[0]).unwrap(), &keys[1..], value);
        }
        NestedValue::Vec(vec) => {
            let index = keys[0].parse::<usize>().unwrap();
            if index >= vec.len() {
                vec.resize_with(index + 1, || NestedValue::Map(HashMap::new()));
            }
            insert_nested_value(&mut vec[index], &keys[1..], value);
        }
        _ => panic!("Invalid structure encountered"),
    }
}

/// A trait for encapsulating the serialization logic.
pub trait Serializable {
    /// Serializes the object into a `NestedValue` using the provided `Serializer`.
    /// This method is generic over the precision settings `PS`.
    ///
    /// # Parameters
    /// - `serializer`: The `Serializer` to use for serializing the object.
    ///
    /// # Returns
    /// - `Result<NestedValue, Error>`: The result of serialization.
    ///   Returns a `NestedValue` on success,
    ///   or an `Error` on failure.
    ///
    /// # Type Parameters
    /// - `PS`: The precision settings to use during serialization.
    ///   This is a generic parameter and can be any type
    ///   that implements the `PrecisionSettings` trait.
    fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, Error>
    where
        PS: PrecisionSettings;
}

/// Convert a vector of tensors to a nested value.
pub fn unflatten<PS, T>(input: HashMap<String, T>) -> Result<NestedValue, Error>
where
    PS: PrecisionSettings,
    T: Serializable,
{
    let mut result = NestedValue::Map(HashMap::new());

    for (key, value) in input {
        let parts: Vec<&str> = key.split('.').collect();
        let st = value.serialize::<PS>(Serializer::new())?;

        insert_nested_value(&mut result, &parts, st);
    }

    Ok(result)
}
```
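As a usage sketch for `remap` (the module path, pattern, and stand-in value type are assumptions): stripping a `model.` prefix from every tensor key before deserialization could look like this.

```rust
use std::collections::HashMap;

use burn::record::serde::data::remap; // path assumed
use regex::Regex;

fn main() {
    let mut tensors: HashMap<String, u32> = HashMap::new(); // u32 stands in for tensor data
    tensors.insert("model.fc1.weight".to_string(), 0);

    // First matching pattern wins, as in the loop above.
    let key_remap = vec![(Regex::new(r"^model\.(.*)").unwrap(), "$1".to_string())];

    let remapped = remap(tensors, key_remap);
    assert!(remapped.contains_key("fc1.weight"));
}
```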
```
@@ -0,0 +1,703 @@
use std::collections::HashMap;

use super::adapter::DefaultAdapter;
use super::data::NestedValue;
use super::{adapter::BurnModuleAdapter, error::Error};

use serde::{
    de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
    forward_to_deserialize_any,
};

const RECORD_ITEM_SUFFIX: &str = "RecordItem";

/// A deserializer for the nested value data structure.
pub struct Deserializer<A: BurnModuleAdapter> {
    // This string starts with the input data and characters are truncated off
    // the beginning as data is parsed.
    value: Option<NestedValue>,
    default_for_missing_fields: bool,
    phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> Deserializer<A> {
    /// Creates a new deserializer with the given nested value.
    ///
    /// # Arguments
    ///
    /// * `value` - A nested value.
    /// * `default_for_missing_fields` - A boolean indicating whether to add missing fields with default value.
    pub fn new(value: NestedValue, default_for_missing_fields: bool) -> Self {
        Self {
            value: Some(value),
            default_for_missing_fields,
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
    type Error = Error;

    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_any is not implemented")
    }

    fn deserialize_struct<V>(
        self,
        name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let value = match self.value {
            Some(value) => {
                // Adapt modules
                if let Some(name) = name.strip_suffix(RECORD_ITEM_SUFFIX) {
                    A::adapt(name, value)
                } else {
                    value
                }
            }
            None => {
                return Err(de::Error::custom(format!(
                    "Expected some value but got {:?}",
                    self.value
                )))
            }
        };

        match value {
            NestedValue::Map(map) => {
                // Add missing fields into the map with default value if needed.
                let map = if self.default_for_missing_fields {
                    let mut map = map;
                    for field in fields.iter().map(|s| s.to_string()) {
                        map.entry(field).or_insert(NestedValue::Default);
                    }
                    map
                } else {
                    map
                };

                visitor.visit_map(HashMapAccess::<A>::new(
                    map,
                    self.default_for_missing_fields,
                ))
            }

            _ => Err(de::Error::custom(format!(
                "Expected struct but got {:?}",
                value
            ))),
        }
    }

    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_string(self.value.unwrap().as_string().unwrap().to_string())
    }

    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_unit()
    }

    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        match self.value {
            Some(NestedValue::Map(map)) => visitor.visit_map(HashMapAccess::<A>::new(
                map,
                self.default_for_missing_fields,
            )),

            _ => Err(de::Error::custom(format!(
                "Expected map value but got {:?}",
                self.value
            ))),
        }
    }

    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_bool(self.value.unwrap().as_bool().unwrap())
    }

    fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_i8 is not implemented")
    }

    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i16(self.value.unwrap().as_i16().unwrap().to_owned())
    }

    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i32(self.value.unwrap().as_i32().unwrap().to_owned())
    }

    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i64(self.value.unwrap().as_i64().unwrap().to_owned())
    }

    fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_u8 is not implemented")
    }

    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u16(self.value.unwrap().as_u16().unwrap().to_owned())
    }

    fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_u32 is not implemented")
    }

    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u64(self.value.unwrap().as_u64().unwrap().to_owned())
    }

    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_f32(self.value.unwrap().as_f32().unwrap().to_owned())
    }

    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_f64(self.value.unwrap().as_f64().unwrap().to_owned())
    }

    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_char is not implemented")
    }

    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_str(self.value.unwrap().as_string().unwrap().as_ref())
    }

    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_bytes is not implemented")
    }

    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_byte_buf is not implemented")
    }

    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        if let Some(value) = self.value {
            visitor.visit_some(Deserializer::<A>::new(
                value,
                self.default_for_missing_fields,
            ))
        } else {
            visitor.visit_none()
        }
    }

    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_unit is not implemented")
    }

    fn deserialize_unit_struct<V>(
        self,
        _name: &'static str,
        _visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_unit_struct is not implemented")
    }

    fn deserialize_newtype_struct<V>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_newtype_struct(Deserializer::<A>::new(
            self.value.unwrap(),
            self.default_for_missing_fields,
        ))
    }

    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        if let Some(NestedValue::Vec(vec)) = self.value {
            visitor.visit_seq(VecSeqAccess::<A>::new(vec, self.default_for_missing_fields))
        } else {
            Err(de::Error::custom(format!(
                "Expected Vec but got {:?}",
                self.value
            )))
        }
    }

    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_tuple is not implemented")
    }

    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        _visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_tuple_struct is not implemented")
    }

    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        _visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_enum is not implemented")
    }

    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!("deserialize_identifier is not implemented")
    }
}

/// A sequence access for a vector in the nested value data structure.
struct VecSeqAccess<A: BurnModuleAdapter> {
    iter: std::vec::IntoIter<NestedValue>,
    default_for_missing_fields: bool,
    phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> VecSeqAccess<A> {
    fn new(vec: Vec<NestedValue>, default_for_missing_fields: bool) -> Self {
        VecSeqAccess {
            iter: vec.into_iter(),
            default_for_missing_fields,
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'de, A> SeqAccess<'de> for VecSeqAccess<A>
where
    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
    A: BurnModuleAdapter,
{
    type Error = Error;

    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        let item = match self.iter.next() {
            Some(v) => v,
            None => return Ok(None),
        };

        seed.deserialize(
            NestedValueWrapper::<A>::new(item, self.default_for_missing_fields).into_deserializer(),
        )
        .map(Some)
    }
}

/// A map access for a map in the nested value data structure.
struct HashMapAccess<A: BurnModuleAdapter> {
    iter: std::collections::hash_map::IntoIter<String, NestedValue>,
    next_value: Option<NestedValue>,
    default_for_missing_fields: bool,
    phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> HashMapAccess<A> {
    fn new(map: HashMap<String, NestedValue>, default_for_missing_fields: bool) -> Self {
        HashMapAccess {
            iter: map.into_iter(),
            next_value: None,
            default_for_missing_fields,
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'de, A> MapAccess<'de> for HashMapAccess<A>
where
    String: IntoDeserializer<'de, Error>,
    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
    A: BurnModuleAdapter,
{
    type Error = Error;

    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        match self.iter.next() {
            Some((k, v)) => {
                // Keep the value for the next call to next_value_seed.
                self.next_value = Some(v);
                // Deserialize the key.
                seed.deserialize(k.into_deserializer()).map(Some)
            }
            None => Ok(None),
        }
    }

    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        match self.next_value.take() {
            Some(NestedValue::Default) => seed.deserialize(DefaultDeserializer),
            Some(v) => seed.deserialize(
                NestedValueWrapper::new(v, self.default_for_missing_fields).into_deserializer(),
            ),
            None => seed.deserialize(DefaultDeserializer),
        }
    }
}

/// A wrapper for the nested value data structure with a burn module adapter.
struct NestedValueWrapper<A: BurnModuleAdapter> {
    value: NestedValue,
    default_for_missing_fields: bool,
    phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> NestedValueWrapper<A> {
    fn new(value: NestedValue, default_for_missing_fields: bool) -> Self {
        Self {
            value,
            default_for_missing_fields,
            phantom: std::marker::PhantomData,
        }
    }
}

impl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A> {
    type Deserializer = Deserializer<A>;

    fn into_deserializer(self) -> Self::Deserializer {
        Deserializer::<A>::new(self.value, self.default_for_missing_fields)
    }
}

/// A default deserializer that always returns the default value.
struct DefaultDeserializer;

impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
    type Error = Error;

    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        unimplemented!()
    }

    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i32(Default::default())
    }

    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_f32(Default::default())
    }

    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i16(Default::default())
    }

    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i64(Default::default())
    }

    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u16(Default::default())
    }

    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u64(Default::default())
    }

    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_f64(Default::default())
    }

    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_bool(Default::default())
    }

    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_char(Default::default())
    }

    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_str(Default::default())
    }

    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_i8(Default::default())
    }

    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u8(Default::default())
    }

    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u32(Default::default())
    }

    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_none()
    }

    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_seq(DefaultSeqAccess::new(None))
    }

    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_string(Default::default())
    }

    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let mut map: HashMap<String, NestedValue> = HashMap::new();

        for field in _fields.iter().map(|s| s.to_string()) {
            map.insert(field, NestedValue::Default);
        }

        visitor.visit_map(HashMapAccess::<DefaultAdapter>::new(map, true))
    }

    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        len: usize,
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_seq(DefaultSeqAccess::new(Some(len)))
    }

    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_seq(DefaultSeqAccess::new(Some(len)))
    }

    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_map(DefaultMapAccess::new())
    }

    forward_to_deserialize_any! {
        u128 bytes byte_buf unit unit_struct newtype_struct
        enum identifier ignored_any
    }
}

/// A default sequence access that always returns None (empty sequence).
pub struct DefaultSeqAccess {
    size: Option<usize>,
}

impl Default for DefaultSeqAccess {
    fn default() -> Self {
        Self::new(None)
    }
}

impl DefaultSeqAccess {
    /// Creates a new default sequence access with the given size hint.
    pub fn new(size: Option<usize>) -> Self {
        DefaultSeqAccess { size }
    }
}

impl<'de> SeqAccess<'de> for DefaultSeqAccess {
    type Error = Error;

    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        seed.deserialize(DefaultDeserializer).map(Some)
    }

    fn size_hint(&self) -> Option<usize> {
        self.size
    }
}

/// A default map access that always returns None (empty map).
pub struct DefaultMapAccess;

impl Default for DefaultMapAccess {
    fn default() -> Self {
        Self::new()
    }
}

impl DefaultMapAccess {
    /// Creates a new default map access.
    pub fn new() -> Self {
        DefaultMapAccess
    }
}

impl<'de> MapAccess<'de> for DefaultMapAccess {
    type Error = Error;

    fn next_key_seed<T>(&mut self, _seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        // Since this is a default implementation, we'll just return None.
        Ok(None)
    }

    fn next_value_seed<T>(&mut self, _seed: T) -> Result<T::Value, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        unimplemented!("This should never be called since next_key_seed always returns None")
    }

    fn size_hint(&self) -> Option<usize> {
        // Since this is a default implementation, we'll just return None.
        None
    }
}
```
```
@@ -0,0 +1,40 @@
use crate::record::RecorderError;

/// The error type for Record serde.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Failed to deserialize.
    #[error("failed to deserialize")]
    Deserialize(#[from] serde::de::value::Error),

    /// Failed to serialize.
    #[error("failed to serialize")]
    Serialize(String),

    /// Encountered an invalid state.
    #[error("invalid state")]
    InvalidState,

    /// Other error.
    #[error("other error: {0}")]
    Other(String),
}

impl serde::de::Error for Error {
    fn custom<T: std::fmt::Display>(msg: T) -> Self {
        Error::Deserialize(serde::de::value::Error::custom(msg.to_string()))
    }
}

impl serde::ser::Error for Error {
    fn custom<T: std::fmt::Display>(msg: T) -> Self {
        Error::Serialize(msg.to_string())
    }
}

// Implement From trait for Error to RecorderError
impl From<Error> for RecorderError {
    fn from(error: Error) -> Self {
        RecorderError::DeserializeError(error.to_string())
    }
}
```
```
@@ -0,0 +1,17 @@
//! Module contains the serde implementation for the record module
//! useful for custom importing model weights, such as PyTorch's pt file format.

/// The adapter trait that is used to convert the nested value to the module type.
pub mod adapter;

/// The main data structure used for deserialization.
pub mod data;

/// The deserializer that is used to convert the nested value to the record.
pub mod ser;

/// The deserializer that is used to convert the nested value to the record.
pub mod de;

/// Error types.
pub mod error;
```
```
@@ -0,0 +1,375 @@
use std::collections::HashMap;

use super::{
    data::NestedValue,
    error::{self, Error},
};

use serde::{
    ser::{self, SerializeSeq, SerializeStruct, Serializer as SerializerTrait},
    Serialize,
};

/// Simple struct serializer that converts a struct into NestedValues.
///
/// NOTE: This is used to serialize Param structs into NestedValues and not so much for
/// the actual serialization of modules (although it could be used for that as well if all
/// primitive types are implemented).
pub struct Serializer {
    // The state of the serialization process
    state: Option<NestedValue>,
}

impl Serializer {
    /// Creates a new serializer.
    pub fn new() -> Self {
        Serializer { state: None }
    }
}

impl Default for Serializer {
    fn default() -> Self {
        Self::new()
    }
}

impl SerializerTrait for Serializer {
    type Ok = NestedValue;
    type Error = Error;
    type SerializeSeq = Self;
    type SerializeTuple = ser::Impossible<NestedValue, Self::Error>;
    type SerializeTupleStruct = ser::Impossible<NestedValue, Self::Error>;
    type SerializeTupleVariant = ser::Impossible<NestedValue, Self::Error>;
    type SerializeMap = ser::Impossible<NestedValue, Self::Error>;
    type SerializeStruct = Self;
    type SerializeStructVariant = ser::Impossible<NestedValue, Self::Error>;

    fn serialize_struct(
        self,
        _name: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeStruct, Self::Error> {
        Ok(self)
    }

    fn serialize_newtype_struct<T: ?Sized>(
        self,
        _name: &'static str,
        value: &T,
    ) -> Result<Self::Ok, Self::Error>
    where
        T: Serialize,
    {
        value.serialize(self)
    }

    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
        Ok(self)
    }

    fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::I32(v))
    }

    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::String(v.to_string()))
    }

    fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::I16(v))
    }

    fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::I64(v))
    }

    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::U16(v))
    }

    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::U64(v))
    }

    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::F32(v))
    }

    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::F64(v))
    }

    // The following methods are not implemented because they are not needed for the
    // serialization of Param structs.

    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
        Ok(NestedValue::Default)
    }

    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
    where
        T: Serialize,
    {
        value.serialize(self)
    }

    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_unit_variant(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
    ) -> Result<Self::Ok, Self::Error> {
        unimplemented!()
    }

    fn serialize_newtype_variant<T: ?Sized>(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
        _value: &T,
    ) -> Result<Self::Ok, Self::Error>
    where
        T: Serialize,
    {
        unimplemented!()
    }

    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
        unimplemented!()
    }

    fn serialize_tuple_struct(
        self,
        _name: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
        unimplemented!()
    }

    fn serialize_tuple_variant(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
        unimplemented!()
    }

    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
        unimplemented!()
    }

    fn serialize_struct_variant(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeStructVariant, Self::Error> {
        unimplemented!()
    }
}

// Implementing the SerializeStruct trait for Serializer
impl SerializeStruct for Serializer {
    type Ok = NestedValue;
    type Error = Error;

    fn serialize_field<T: ?Sized>(
        &mut self,
        key: &'static str,
        value: &T,
    ) -> Result<(), Self::Error>
    where
        T: Serialize,
    {
        let serialized_value = value.serialize(Serializer::new())?;

        match self.state {
            Some(NestedValue::Map(ref mut map)) => {
                map.insert(key.to_string(), serialized_value); // Inserting into the state
            }
            Some(_) => {
                panic!("Invalid state encountered");
            }
            None => {
                let mut map = HashMap::new();
                map.insert(key.to_string(), serialized_value); // Inserting into the state
                self.state = Some(NestedValue::Map(map));
            }
        }

        Ok(())
    }

    fn end(self) -> Result<Self::Ok, Self::Error> {
        if self.state.is_none() {
            // If the state is empty, return an empty map
            Ok(NestedValue::Map(HashMap::new()))
        } else {
            self.state.ok_or(error::Error::InvalidState)
        }
    }
}

impl SerializeSeq for Serializer {
    type Ok = NestedValue;
    type Error = Error;

    fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
    where
        T: Serialize,
    {
        let serialized_value = value.serialize(Serializer::new())?;

        match self.state {
            Some(NestedValue::Vec(ref mut vec)) => {
                vec.push(serialized_value); // Inserting into the state
            }
            Some(_) => {
                panic!("Invalid state encountered");
            }
            None => {
                self.state = Some(NestedValue::Vec(vec![serialized_value]));
            }
        }

        Ok(())
    }

    fn end(self) -> Result<Self::Ok, Self::Error> {
        if self.state.is_none() {
            // If the state is empty, return an empty vector
            Ok(NestedValue::Vec(Vec::new()))
        } else {
            self.state.ok_or(error::Error::InvalidState)
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        module::{Param, ParamId},
        record::{FullPrecisionSettings, Record},
        tensor::Tensor,
    };
    use serde::Deserialize;

    use super::*;

    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct1 {
        a: MyStruct3,
        b: MyStruct2,
    }

    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct2 {
        a: i32,
        b: Option<i32>,
        c: String,
        d: Option<String>,
    }

    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct MyStruct3 {
        x: String,
        y: String,
    }

    #[test]
    fn test_serialize() {
        let my_struct = MyStruct1 {
            a: MyStruct3 {
                x: "Hello".to_owned(),
                y: "World".to_owned(),
            },
            b: MyStruct2 {
                a: 1,
                b: None,
                c: "Hello".to_owned(),
                d: Some("World".to_owned()),
            },
        };

        let serialized = my_struct
            .serialize(Serializer::new())
            .expect("Should serialize item successfully");

        let serialized_str = format!("{:?}", serialized);

        // Compare the lengths of expected and actual serialized strings because
        // the order of the fields is not guaranteed for HashMaps.
        assert_eq!(
            serialized_str.len(),
            concat!(
                r#"Map({"b": Map({"a": I32(1), "c": String("Hello"), "b": Default, "d": String("World")}),"#,
                r#" "a": Map({"x": String("Hello"), "y": String("World")})})"#
            ).len()
        );
    }

    #[test]
    fn test_param_serde() {
        type Backend = burn_ndarray::NdArray<f32>;

        let device = Default::default();

        let tensor: Tensor<Backend, 2> = Tensor::ones([2, 2], &device);

        let param = Param::new(ParamId::new(), tensor);

        let param_item = param.into_item::<FullPrecisionSettings>();

        let serialized = param_item
            .serialize(Serializer::new())
            .expect("Should serialize item successfully");

        let serialized_str = format!("{:?}", serialized);

        // Compare the lengths of expected and actual serialized strings because
        // the order of the fields is not guaranteed for HashMaps.
        assert_eq!(
            serialized_str.len(),
            concat!(
                r#"Map({"id": String("ca893b0b-92cf-4856-a1c2-558191dbb930"), "#,
                r#""param": Map({"shape": Vec([U64(2), U64(2)]), "#,
                r#""value": Vec([F32(1.0), F32(1.0), F32(1.0), F32(1.0)])})})"#
            )
            .len()
        );
    }
}
```
```
@@ -13,5 +13,5 @@ pub struct TestWithBackendRecord<B: Backend> {
// It compiles
#[derive(Record)]
pub struct TestWithoutBackendRecord {
    tensor: usize,
    _tensor: usize,
}
```
```
@@ -28,6 +28,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
                /// Field to be serialized.
                pub #name: <#ty as burn::record::Record<B>>::Item<S>,
            });

            bounds.extend(quote! {
                <#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
            });
```
```
@@ -1,7 +1,7 @@
[package]
authors = [
    "Dilshod Tadjibaev (@antimora)",
    "Nathaniel Simard (@nathanielsimard)",
    "Dilshod Tadjibaev (@antimora)",
    "Nathaniel Simard (@nathanielsimard)",
]
description = "Library for importing datamodels into the Burn framework"
edition.workspace = true
@@ -9,35 +9,40 @@ license.workspace = true
name = "burn-import"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-import"

version.workspace = true

default-run = "onnx2burn"

[features]
default = ["onnx"]
default = ["onnx", "pytorch"]
onnx = []
pytorch = ["burn/record-item-custom-serde", "thiserror"]

[dependencies]
burn = { path = "../burn", version = "0.12.0" }
burn-ndarray = { path = "../burn-ndarray", version = "0.12.0" }
burn = { path = "../burn", version = "0.12.0", features = ["ndarray"] }

bytemuck = { workspace = true }
candle-core = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true }
log = { workspace = true }
proc-macro2 = { workspace = true }
protobuf = { version = "3.3", features = ["with-bytes"] }
protobuf = { workspace = true, features = ["with-bytes"] }
quote = { workspace = true }
rust-format = { version = "0.3", features = ["token_stream", "post_process"] }
serde = { workspace = true }
regex = { workspace = true }
rust-format = { workspace = true, features = ["token_stream", "post_process"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true, features = ["std"] }
strum = { workspace = true }
strum_macros = { workspace = true }
syn = { workspace = true, features = ["parsing"] }
tracing-subscriber = { workspace = true }
thiserror = { workspace = true, optional = true }
tracing-core = { workspace = true }
tracing-subscriber = { workspace = true }

[build-dependencies]
protobuf-codegen = { workspace = true }

[dev-dependencies]
pretty_assertions = { workspace = true }
rstest = { workspace = true }
```
Binary file not shown.

```
@@ -0,0 +1,17 @@
[package]
name = "pytorch-tests"
version.workspace = true
edition.workspace = true
license.workspace = true

[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["pytorch"] }
cfg-if = "1.0.0"

[build-dependencies]
burn-import = { path = "../", features = ["pytorch"] }
```
```
@@ -0,0 +1 @@
```
Binary file not shown.
```
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.norm1 = nn.BatchNorm2d(5)

    def forward(self, x):
        x = self.norm1(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # Condition batch norm (each forward will affect the running stats)
    x1 = torch.ones(1, 5, 2, 2) - 0.5
    _ = model(x1)
    model.eval()  # Set to eval mode to freeze running stats
    # Save the model after the first forward
    torch.save(model.state_dict(), "batch_norm2d.pt")

    x2 = torch.ones(1, 5, 2, 2) - 0.3
    print("Input shape: {}", x2.shape)
    output = model(x2)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()
```
@ -0,0 +1,60 @@
use burn::{
    module::Module,
    nn::{BatchNorm, BatchNormConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    norm1: BatchNorm<B, 2>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        // 5 features, matching the `nn.BatchNorm2d(5)` in the Python script.
        let norm1 = BatchNormConfig::new(5).init_with(record.norm1);
        Self { norm1 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.norm1.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    #[test]
    fn batch_norm2d() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/batch_norm/batch_norm2d.pt".into(), &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::ones([1, 5, 2, 2], &device) - 0.3;

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
            ]],
            &device,
        );

        output.to_data().assert_approx_eq(&expected.to_data(), 5);
    }
}

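A natural follow-up to tests like the one above is a one-time conversion: load the `.pt` file once and re-save the record in a Burn-native format so the PyTorch file is not needed at runtime. A sketch, assuming the `Net`/`NetRecord` types from the file above are in scope and that Burn's `NamedMpkFileRecorder` is the target format:

    use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    type Backend = burn_ndarray::NdArray<f32>;

    fn convert() {
        let device = Default::default();
        let record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/batch_norm/batch_norm2d.pt".into(), &device)
            .expect("Should decode state successfully");

        // Re-save in Burn's named MessagePack format for direct loading later.
        NamedMpkFileRecorder::<FullPrecisionSettings>::default()
            .record(record, "batch_norm2d.mpk".into())
            .expect("Should save the record successfully");
    }
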
Binary file not shown.

@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        buffer = torch.tensor([True, False, True])
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        x = self.buffer
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "boolean.pt")

    input = torch.ones(3, 3)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,58 @@
use burn::{
    module::{Module, Param},
    tensor::{backend::Backend, Bool, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 1, Bool>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        Self {
            buffer: record.buffer,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
        self.buffer.val()
    }
}

#[cfg(test)]
mod tests {

    use burn::{
        record::{FullPrecisionSettings, Recorder},
        tensor::Data,
    };
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    #[ignore = "It appears loading boolean tensors is not supported yet"]
    // Skipped with error: Msg("unsupported storage type BoolStorage")
    fn boolean() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/boolean/boolean.pt".into(), &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 2>::ones([3, 3], &device);

        let output = model.forward(input);

        let expected =
            Tensor::<Backend, 1, Bool>::from_bool(Data::from([true, false, true]), &device);

        assert_eq!(output.to_data(), expected.to_data());
    }
}

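Until boolean storage is supported by the reader, one possible workaround (an untested sketch, not part of this commit) is to save the flag as a 0/1 integer buffer on the Python side — as the `integer` test later in this commit does — and convert it to a boolean tensor after loading:

    use burn::{
        module::{Module, Param},
        tensor::{backend::Backend, Bool, Int, Tensor},
    };

    /// Workaround sketch: store the mask as integers in the checkpoint.
    #[derive(Module, Debug)]
    pub struct NetWorkaround<B: Backend> {
        mask: Param<Tensor<B, 1, Int>>,
    }

    impl<B: Backend> NetWorkaround<B> {
        /// Convert the loaded integer buffer to a boolean mask (1 => true).
        pub fn mask(&self) -> Tensor<B, 1, Bool> {
            self.mask.val().equal_elem(1)
        }
    }
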
Binary file not shown.

@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        buffer = torch.ones(3, 3)
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        x = self.buffer + x
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "buffer.pt")

    input = torch.ones(3, 3)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,51 @@
use burn::{
    module::{Module, Param},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 2>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        Self {
            buffer: record.buffer,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        self.buffer.val() + x
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    #[test]
    fn buffer() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/buffer/buffer.pt".into(), &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 2>::ones([3, 3], &device);

        let output = model.forward(input);

        let expected = Tensor::<Backend, 2>::ones([3, 3], &device) * 2.0;

        output.to_data().assert_approx_eq(&expected.to_data(), 3);
    }
}

Binary file not shown.

@ -0,0 +1,69 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv_blocks = nn.Sequential(
            ConvBlock(2, 4, (3, 2)),
            ConvBlock(4, 6, (3, 2)),
        )
        self.norm1 = nn.BatchNorm2d(6)

        self.fc1 = nn.Linear(120, 12)
        self.fc2 = nn.Linear(12, 10)

    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.norm1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(2)

    model = Net().to(torch.device("cpu"))

    # Condition the model (batch norm requires a forward pass to compute the mean and variance)
    x1 = torch.ones(1, 2, 9, 6) - 0.1
    x2 = torch.ones(1, 2, 9, 6) - 0.3
    output = model(x1)
    output = model(x2)
    model.eval()  # set to eval mode

    torch.save(model.state_dict(), "complex_nested.pt")

    # feed test data
    x = torch.ones(1, 2, 9, 6) - 0.5
    output = model(x)
    print("Input shape: {}", x.shape)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,229 @@
use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};

use burn::{
    module::Module,
    nn::{
        conv::{Conv2d, Conv2dConfig},
        BatchNorm, BatchNormConfig, Linear, LinearConfig,
    },
    tensor::{
        activation::{log_softmax, relu},
        backend::Backend,
        Tensor,
    },
};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Conv2d<B>,
    norm: BatchNorm<B, 2>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv_blocks: Vec<ConvBlock<B>>,
    norm1: BatchNorm<B, 2>,
    fc1: Linear<B>,
    fc2: Linear<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(mut record: NetRecord<B>) -> Self {
        let record_block_0 = record.conv_blocks.remove(0);
        let record_block_1 = record.conv_blocks.remove(0);

        // The norm sizes match each block's conv output channels (4 and 6).
        let conv_blocks = vec![
            ConvBlock {
                conv: Conv2dConfig::new([2, 4], [3, 2]).init_with(record_block_0.conv),
                norm: BatchNormConfig::new(4).init_with(record_block_0.norm),
            },
            ConvBlock {
                conv: Conv2dConfig::new([4, 6], [3, 2]).init_with(record_block_1.conv),
                norm: BatchNormConfig::new(6).init_with(record_block_1.norm),
            },
        ];
        let norm1 = BatchNormConfig::new(6).init_with(record.norm1);
        let fc1 = LinearConfig::new(120, 12).init_with(record.fc1);
        let fc2 = LinearConfig::new(12, 10).init_with(record.fc2);
        Self {
            conv_blocks,
            norm1,
            fc1,
            fc2,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        let x = self.conv_blocks[0].forward(x);
        let x = self.conv_blocks[1].forward(x);
        let x = self.norm1.forward(x);
        let x = x.reshape([0, -1]);
        let x = self.fc1.forward(x);
        let x = relu(x);
        let x = self.fc2.forward(x);

        log_softmax(x, 1)
    }
}

impl<B: Backend> ConvBlock<B> {
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(x);

        self.norm.forward(x)
    }
}

/// Partial model to test loading of partial records.
#[derive(Module, Debug)]
pub struct PartialNet<B: Backend> {
    conv1: ConvBlock<B>,
}

impl<B: Backend> PartialNet<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: PartialNetRecord<B>) -> Self {
        let conv1 = ConvBlock {
            conv: Conv2dConfig::new([2, 4], [3, 2]).init_with(record.conv1.conv),
            norm: BatchNormConfig::new(4).init_with(record.conv1.norm),
        };
        Self { conv1 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.conv1.forward(x)
    }
}

/// Model with extra fields to test loading of records (e.g. from a different model).
#[derive(Module, Debug)]
pub struct PartialWithExtraNet<B: Backend> {
    conv1: ConvBlock<B>,
    extra_field: bool, // This field is not present in the PyTorch model.
}

impl<B: Backend> PartialWithExtraNet<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: PartialWithExtraNetRecord<B>) -> Self {
        let conv1 = ConvBlock {
            conv: Conv2dConfig::new([2, 4], [3, 2]).init_with(record.conv1.conv),
            norm: BatchNormConfig::new(4).init_with(record.conv1.norm),
        };
        Self {
            conv1,
            extra_field: true,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.conv1.forward(x)
    }
}

type TestBackend = burn_ndarray::NdArray<f32>;

fn model_test(record: NetRecord<TestBackend>, precision: usize) {
    let device = Default::default();
    let model = Net::<TestBackend>::new_with(record);

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    let output = model.forward(input);

    let expected = Tensor::<TestBackend, 2>::from_data(
        [[
            -2.306_613,
            -2.058_945_4,
            -2.298_372_7,
            -2.358_294,
            -2.296_395_5,
            -2.416_090_5,
            -2.107_669,
            -2.428_420_8,
            -2.526_469,
            -2.319_918_6,
        ]],
        &device,
    );

    output
        .to_data()
        .assert_approx_eq(&expected.to_data(), precision);
}

#[test]
fn full_record() {
    let device = Default::default();
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("tests/complex_nested/complex_nested.pt".into(), &device)
        .expect("Should decode state successfully");

    model_test(record, 8);
}

#[test]
fn half_record() {
    let device = Default::default();
    let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
        .load("tests/complex_nested/complex_nested.pt".into(), &device)
        .expect("Should decode state successfully");

    model_test(record, 4);
}

#[test]
fn partial_model_loading() {
    // Load the full model but rename "conv_blocks.0.*" to "conv1.*".
    let load_args = LoadArgs::new("tests/complex_nested/complex_nested.pt".into())
        .with_key_remap("conv_blocks\\.0\\.(.*)", "conv1.$1");

    let device = Default::default();
    // Load the partial record from the full model.
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args, &device)
        .expect("Should decode state successfully");

    let model = PartialNet::<TestBackend>::new_with(record);

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    let output = model.forward(input);

    // Sum all elements of the output tensor for a quick check.
    let sum = output.sum();

    assert_eq!(4.871538, sum.into_scalar());
}

#[test]
fn extra_field_model_loading() {
    // Load the full model but rename "conv_blocks.0.*" to "conv1.*".
    let load_args = LoadArgs::new("tests/complex_nested/complex_nested.pt".into())
        .with_key_remap("conv_blocks\\.0\\.(.*)", "conv1.$1");

    let device = Default::default();

    // Load the partial record from the full model.
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args, &device)
        .expect("Should decode state successfully");

    let model = PartialWithExtraNet::<TestBackend>::new_with(record);

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    let output = model.forward(input);

    // Sum all elements of the output tensor for a quick check.
    let sum = output.sum();

    assert_eq!(4.871538, sum.into_scalar());

    assert!(model.extra_field);
}

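`LoadArgs` remaps can be chained, each one applied to the state-dict keys in turn before deserialization. A sketch extending the pattern above — the second remap is hypothetical, added only to show chaining:

    use burn_import::pytorch::LoadArgs;

    fn remapped_args() -> LoadArgs {
        LoadArgs::new("tests/complex_nested/complex_nested.pt".into())
            // As in the tests above: "conv_blocks.0.*" becomes "conv1.*".
            .with_key_remap("conv_blocks\\.0\\.(.*)", "conv1.$1")
            // Hypothetical second rename for the second block.
            .with_key_remap("conv_blocks\\.1\\.(.*)", "conv2.$1")
    }
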
Binary file not shown.

@ -0,0 +1,39 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv1d(2, 2, 2)
        self.conv2 = nn.Conv1d(2, 2, 2, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv1d.pt")

    input = torch.rand(1, 2, 6)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,95 @@
use burn::{
    module::Module,
    nn::conv::{Conv1d, Conv1dConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv1d<B>,
    conv2: Conv1d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let conv1 = Conv1dConfig::new(2, 2, 2).init_with(record.conv1);
        let conv2 = Conv1dConfig::new(2, 2, 2)
            .with_bias(false)
            .init_with(record.conv2);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = self.conv1.forward(x);

        self.conv2.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;
    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn conv1d(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 3>::from_data(
            [[
                [
                    0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965,
                ],
                [
                    0.33977574,
                    0.523_940_8,
                    0.798_063_9,
                    0.77176833,
                    0.01122457,
                    0.80996025,
                ],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 3>::from_data(
            [[
                [0.02987457, 0.03134188, 0.04234261, -0.02437721],
                [-0.03788019, -0.02972012, -0.00806090, -0.01981254],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn conv1d_full_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/conv1d/conv1d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv1d(record, 7);
    }

    #[test]
    fn conv1d_half_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/conv1d/conv1d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv1d(record, 4);
    }
}

Binary file not shown.

@ -0,0 +1,39 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv2d.pt")

    input = torch.rand(1, 2, 5, 5)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,131 @@
use burn::{
    module::Module,
    nn::conv::{Conv2d, Conv2dConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init_with(record.conv1);
        let conv2 = Conv2dConfig::new([2, 2], [2, 2])
            .with_bias(false)
            .init_with(record.conv2);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv1.forward(x);

        self.conv2.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn conv2d(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [
                        0.024_595_8,
                        0.25883394,
                        0.93905586,
                        0.416_715_5,
                        0.713_979_7,
                    ],
                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
                    [
                        0.99802476,
                        0.900_794_2,
                        0.476_588_2,
                        0.16625845,
                        0.804_481_1,
                    ],
                ],
                [
                    [
                        0.65517855,
                        0.17679012,
                        0.824_772_3,
                        0.803_550_9,
                        0.943_447_5,
                    ],
                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
                    [
                        0.751_675_7,
                        0.148_438_4,
                        0.12274551,
                        0.530_407_2,
                        0.414_796_4,
                    ],
                ],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [-0.02502128, 0.00250649, 0.04841233],
                    [0.04589614, -0.00296854, 0.01991477],
                    [0.02920526, 0.059_497_3, 0.04326791],
                ],
                [
                    [-0.04825336, 0.080_190_9, -0.02375088],
                    [0.02885434, 0.09638263, -0.07460806],
                    [0.02004079, 0.06244051, 0.035_887_1],
                ],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn conv2d_full_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/conv2d/conv2d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv2d(record, 7);
    }

    #[test]
    fn conv2d_half_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/conv2d/conv2d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv2d(record, 4);
    }
}

Binary file not shown.

@ -0,0 +1,39 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.ConvTranspose1d(2, 2, 2)
        self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv_transpose1d.pt")

    input = torch.rand(1, 2, 2)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,81 @@
use burn::{
    module::Module,
    nn::conv::{ConvTranspose1d, ConvTranspose1dConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: ConvTranspose1d<B>,
    conv2: ConvTranspose1d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init_with(record.conv1);
        // conv2 is saved without a bias in the Python script above.
        let conv2 = ConvTranspose1dConfig::new([2, 2], 2)
            .with_bias(false)
            .init_with(record.conv2);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = self.conv1.forward(x);

        self.conv2.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn conv_transpose1d(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 3>::from_data(
            [[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 3>::from_data(
            [[
                [0.02935525, 0.01119324, -0.01356167, -0.00682688],
                [0.01644749, -0.01429807, 0.00083987, 0.00279229],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn conv_transpose1d_full() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/conv_transpose1d/conv_transpose1d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv_transpose1d(record, 8);
    }

    #[test]
    fn conv_transpose1d_half() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/conv_transpose1d/conv_transpose1d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv_transpose1d(record, 4);
    }
}

Binary file not shown.

@ -0,0 +1,39 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2))
        self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv_transpose2d.pt")

    input = torch.rand(1, 2, 2, 2)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,94 @@
use burn::{
    module::Module,
    nn::conv::{ConvTranspose2d, ConvTranspose2dConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: ConvTranspose2d<B>,
    conv2: ConvTranspose2d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init_with(record.conv1);
        // conv2 is saved without a bias in the Python script above.
        let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2])
            .with_bias(false)
            .init_with(record.conv2);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv1.forward(x);

        self.conv2.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn conv_transpose2d(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
                [[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.04547675, 0.01879685, -0.01636661, 0.00310803],
                    [0.02090115, 0.01192738, -0.048_240_2, 0.02252235],
                    [0.03249975, -0.00460748, 0.05003899, 0.04029131],
                    [0.02185687, -0.10226749, -0.06508022, -0.01267705],
                ],
                [
                    [0.00277598, -0.00513832, -0.059_048_3, 0.00567626],
                    [-0.03149522, -0.195_757_4, 0.03474613, 0.01997269],
                    [-0.10096474, 0.00679589, 0.041_919_7, -0.02464108],
                    [-0.03174751, 0.02963913, -0.02703723, -0.01860938],
                ],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn conv_transpose2d_full() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/conv_transpose2d/conv_transpose2d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv_transpose2d(record, 7);
    }

    #[test]
    fn conv_transpose2d_half() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/conv_transpose2d/conv_transpose2d.pt".into(), &device)
            .expect("Should decode state successfully");

        conv_transpose2d(record, 4);
    }
}

Binary file not shown.

@ -0,0 +1,37 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embed = nn.Embedding(10, 3)

    def forward(self, x):
        x = self.embed(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "embedding.pt")

    input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

@ -0,0 +1,84 @@
use burn::{
    module::Module,
    nn::{Embedding, EmbeddingConfig},
    tensor::{backend::Backend, Int, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    embed: Embedding<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let embed = EmbeddingConfig::new(10, 3).init_with(record.embed);
        Self { embed }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        self.embed.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;
    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn embedding(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 2, Int>::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device);

        let output = model.forward(input);

        let expected = Tensor::<Backend, 3>::from_data(
            [
                [
                    [-1.609_484_9, -0.10016718, -0.609_188_9],
                    [-0.97977227, -1.609_096_3, -0.712_144_6],
                    [-0.22227049, 1.687_113_4, -0.32062083],
                    [-0.29934573, 1.879_345_7, -0.07213178],
                ],
                [
                    [-0.22227049, 1.687_113_4, -0.32062083],
                    [0.303_722, -0.777_314_3, -0.25145486],
                    [-0.97977227, -1.609_096_3, -0.712_144_6],
                    [-0.02878714, 2.357_111, -1.037_338_7],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn embedding_full_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/embedding/embedding.pt".into(), &device)
            .expect("Should decode state successfully");

        embedding(record, 3);
    }

    #[test]
    fn embedding_half_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/embedding/embedding.pt".into(), &device)
            .expect("Should decode state successfully");

        embedding(record, 3);
    }
}

@ -0,0 +1,37 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.norm1 = nn.GroupNorm(2, 6)

    def forward(self, x):
        x = self.norm1(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "group_norm.pt")

    x2 = torch.rand(1, 6, 2, 2)
    print("Input shape: {}", x2.shape)
    print("Input: {}", x2)
    output = model(x2)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

Binary file not shown.

@ -0,0 +1,88 @@
use burn::{
    module::Module,
    nn::{GroupNorm, GroupNormConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    norm1: GroupNorm<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let norm1 = GroupNormConfig::new(2, 6).init_with(record.norm1);
        Self { norm1 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.norm1.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;
    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn group_norm(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
                [[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]],
                [[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]],
                [[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]],
                [[0.31379688, 0.19801933], [0.41619217, 0.28432965]],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]],
                [[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]],
                [[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]],
                [[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]],
                [[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]],
                [[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn group_norm_full() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/group_norm/group_norm.pt".into(), &device)
            .expect("Should decode state successfully");

        group_norm(record, 3);
    }

    #[test]
    fn group_norm_half() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/group_norm/group_norm.pt".into(), &device)
            .expect("Should decode state successfully");

        group_norm(record, 3);
    }
}

@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        buffer = torch.tensor([1, 2, 3])
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        x = self.buffer
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "integer.pt")

    input = torch.ones(3, 3)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

Binary file not shown.

@ -0,0 +1,69 @@
use burn::{
    module::{Module, Param},
    tensor::{backend::Backend, Int, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 1, Int>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        Self {
            buffer: record.buffer,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Int> {
        self.buffer.val()
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;
    use burn::{
        record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder},
        tensor::Data,
    };
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn integer(record: NetRecord<Backend>, _precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 2>::ones([3, 3], &device);

        let output = model.forward(input);

        let expected = Tensor::<Backend, 1, Int>::from_data(Data::from([1, 2, 3]), &device);

        assert_eq!(output.to_data(), expected.to_data());
    }

    #[test]
    fn integer_full_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/integer/integer.pt".into(), &device)
            .expect("Should decode state successfully");

        integer(record, 0);
    }

    #[test]
    fn integer_half_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/integer/integer.pt".into(), &device)
            .expect("Should decode state successfully");

        integer(record, 0);
    }
}

@ -0,0 +1,48 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvModule(nn.Module):
    def __init__(self):
        super(ConvModule, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = ConvModule()

    def forward(self, x):
        x = self.conv(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "key_remap.pt")

    input = torch.rand(1, 2, 5, 5)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

Binary file not shown.

@ -0,0 +1,116 @@
use burn::{
    module::Module,
    nn::conv::{Conv2d, Conv2dConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init_with(record.conv1);
        let conv2 = Conv2dConfig::new([2, 2], [2, 2])
            .with_bias(false)
            .init_with(record.conv2);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv1.forward(x);

        self.conv2.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

    use super::*;

    #[test]
    fn key_remap() {
        let device = Default::default();
        let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
            .with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"

        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(load_args, &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [
                        0.024_595_8,
                        0.25883394,
                        0.93905586,
                        0.416_715_5,
                        0.713_979_7,
                    ],
                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
                    [
                        0.99802476,
                        0.900_794_2,
                        0.476_588_2,
                        0.16625845,
                        0.804_481_1,
                    ],
                ],
                [
                    [
                        0.65517855,
                        0.17679012,
                        0.824_772_3,
                        0.803_550_9,
                        0.943_447_5,
                    ],
                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
                    [
                        0.751_675_7,
                        0.148_438_4,
                        0.12274551,
                        0.530_407_2,
                        0.414_796_4,
                    ],
                ],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [-0.02502128, 0.00250649, 0.04841233],
                    [0.04589614, -0.00296854, 0.01991477],
                    [0.02920526, 0.059_497_3, 0.04326791],
                ],
                [
                    [-0.04825336, 0.080_190_9, -0.02375088],
                    [0.02885434, 0.09638263, -0.07460806],
                    [0.02004079, 0.06244051, 0.035_887_1],
                ],
            ]],
            &device,
        );

        output.to_data().assert_approx_eq(&expected.to_data(), 7);
    }
}

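The remap pairs are regular-expression pattern/replacement strings (the `regex` crate is added as a dependency earlier in this commit), so `$1`-style references substitute capture groups. A small standalone check of what the `"conv\\.(.*)"` -> `"$1"` remap used above is expected to do to a state-dict key:

    use regex::Regex;

    fn main() {
        let re = Regex::new("conv\\.(.*)").unwrap();
        // "conv.conv1.weight" loses its "conv." prefix, keeping the capture.
        let renamed = re.replace_all("conv.conv1.weight", "$1");
        assert_eq!(renamed, "conv1.weight");
    }
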
@ -0,0 +1,37 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.norm1 = nn.LayerNorm(2)

    def forward(self, x):
        x = self.norm1(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "layer_norm.pt")

    x2 = torch.rand(1, 2, 2, 2)
    print("Input shape: {}", x2.shape)
    print("Input: {}", x2)
    output = model(x2)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

Binary file not shown.

@ -0,0 +1,79 @@
use burn::{
    module::Module,
    nn::{LayerNorm, LayerNormConfig},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    norm1: LayerNorm<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        // Normalized size 2, matching the `nn.LayerNorm(2)` in the Python script.
        let norm1 = LayerNormConfig::new(2).init_with(record.norm1);
        Self { norm1 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.norm1.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn layer_norm(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]],
                [[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn layer_norm_full() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/layer_norm/layer_norm.pt".into(), &device)
            .expect("Should decode state successfully");
        layer_norm(record, 3);
    }

    #[test]
    fn layer_norm_half() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/layer_norm/layer_norm.pt".into(), &device)
            .expect("Should decode state successfully");
        layer_norm(record, 3);
    }
}

@ -0,0 +1,60 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)  # Add relu so that PyTorch optimizer does not combine fc1 and fc2
        x = self.fc2(x)

        return x


class ModelWithBias(nn.Module):
    def __init__(self):
        super(ModelWithBias, self).__init__()
        self.fc1 = nn.Linear(2, 3)

    def forward(self, x):
        x = self.fc1(x)

        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    model_with_bias = ModelWithBias().to(torch.device("cpu"))

    torch.save(model.state_dict(), "linear.pt")
    torch.save(model_with_bias.state_dict(), "linear_with_bias.pt")

    input = torch.rand(1, 2, 2, 2)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)

    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)

    print("Model with bias")
    output = model_with_bias(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == '__main__':
    main()

Binary file not shown.
Binary file not shown.

@ -0,0 +1,149 @@
use burn::{
    module::Module,
    nn::{Linear, LinearConfig, ReLU},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    relu: ReLU,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);
        // fc2 is saved without a bias in the Python script above.
        let fc2 = LinearConfig::new(3, 4)
            .with_bias(false)
            .init_with(record.fc2);
        let relu = ReLU::default();

        Self { fc1, fc2, relu }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.fc1.forward(x);
        let x = self.relu.forward(x);

        self.fc2.forward(x)
    }
}

#[derive(Module, Debug)]
struct NetWithBias<B: Backend> {
    fc1: Linear<B>,
}

impl<B: Backend> NetWithBias<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetWithBiasRecord<B>) -> Self {
        let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);

        Self { fc1 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.fc1.forward(x)
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder};

    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    fn linear_test(record: NetRecord<Backend>, precision: usize) {
        let device = Default::default();
        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
            ]],
            &device,
        );

        let output = model.forward(input);
        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [0.09778349, -0.13756673, 0.04962806, 0.08856435],
                    [0.03163241, -0.02848549, 0.01437942, 0.11905234],
                ],
                [
                    [0.07628226, -0.10757702, 0.03656857, 0.03824598],
                    [0.05443089, -0.06904714, 0.02744314, 0.09997337],
                ],
            ]],
            &device,
        );
        output
            .to_data()
            .assert_approx_eq(&expected.to_data(), precision);
    }

    #[test]
    fn linear_full_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/linear/linear.pt".into(), &device)
            .expect("Should decode state successfully");

        linear_test(record, 7);
    }

    #[test]
    fn linear_half_precision() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::default()
            .load("tests/linear/linear.pt".into(), &device)
            .expect("Should decode state successfully");

        linear_test(record, 4);
    }

    #[test]
    fn linear_with_bias() {
        let device = Default::default();

        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/linear/linear_with_bias.pt".into(), &device)
            .expect("Should decode state successfully");

        let model = NetWithBias::<Backend>::new_with(record);

        let input = Tensor::<Backend, 4>::from_data(
            [[
                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
            ]],
            &device,
        );

        let output = model.forward(input);

        let expected = Tensor::<Backend, 4>::from_data(
            [[
                [
                    [-0.00432095, -1.107_101_2, 0.870_691_4],
                    [0.024_595_5, -0.954_462_9, 0.48518157],
                ],
                [
                    [0.34315687, -0.757_384_2, 0.548_288],
                    [-0.06608963, -1.072_072_7, 0.645_800_5],
                ],
            ]],
            &device,
        );

        output.to_data().assert_approx_eq(&expected.to_data(), 6);
    }
}

@ -0,0 +1,20 @@
cfg_if::cfg_if! {
    if #[cfg(not(target_os = "windows"))] {
        // The crate is not supported on Windows because of Candle's pt bug on Windows
        // (see https://github.com/huggingface/candle/issues/1454).
        mod batch_norm;
        mod boolean;
        mod buffer;
        mod complex_nested;
        mod conv1d;
        mod conv2d;
        mod conv_transpose1d;
        mod conv_transpose2d;
        mod embedding;
        mod group_norm;
        mod integer;
        mod key_remap;
        mod layer_norm;
        mod linear;
    }
}

@ -1,5 +1,7 @@
#[cfg(feature = "onnx")]
use burn_import::onnx::{ModelGen, RecordType};

#[cfg(feature = "onnx")]
/// Takes an ONNX file and generates a model from it
fn main() {
    let onnx_file = std::env::args().nth(1).expect("No input file provided");

@ -15,3 +17,8 @@ fn main() {
        .out_dir(output_dir.as_str())
        .run_from_cli();
}

#[cfg(not(feature = "onnx"))]
fn main() {
    println!("Compiled without ONNX feature.");
}

@ -50,7 +50,7 @@ pub struct BurnGraph<PS: PrecisionSettings> {
}

// The backend used for recording.
type Backend = burn_ndarray::NdArray;
type Backend = burn::backend::ndarray::NdArray;

impl<PS: PrecisionSettings> BurnGraph<PS> {
    /// Register a new operation node into the graph.

@ -400,7 +400,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
    pub fn from_embedded(device: &B::Device) -> Self {
        let record = BinBytesRecorder::<#precision_ty>::default()
            .load(EMBEDDED_STATES.to_vec(), device)
            .expect("Failed to decode state");
            .expect("Should decode state successfully");

        Self::new_with(record)
    }

@ -6,8 +6,8 @@ use super::{
    max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
use burn::record::PrecisionSettings;
use burn_ndarray::NdArray;
use proc_macro2::TokenStream;
use serde::Serialize;

@ -9,16 +9,25 @@
//! aligns the imported model with Burn's model and converts tensor data into a format compatible with
//! Burn.

#[cfg(any(feature = "pytorch", feature = "onnx"))]
#[macro_use]
extern crate derive_new;

// Enabled when the `pytorch` or `onnx` feature is enabled.
#[cfg(any(feature = "pytorch", feature = "onnx"))]
mod logger;

/// The onnx module.
#[cfg(feature = "onnx")]
pub mod onnx;

/// The module for generating the Burn code.
#[cfg(feature = "onnx")]
pub mod burn;

/// The PyTorch recorder module.
#[cfg(feature = "pytorch")]
pub mod pytorch;

mod formatter;
mod logger;
pub use formatter::*;

@ -1,3 +1,5 @@
#![allow(dead_code)]

use std::error::Error;
use tracing_core::LevelFilter;

@ -0,0 +1,102 @@
use burn::{
    module::Param,
    record::{PrecisionSettings, Record},
    tensor::{backend::Backend, Tensor},
};

use burn::record::serde::{
    adapter::{BurnModuleAdapter, DefaultAdapter},
    data::NestedValue,
    ser::Serializer,
};

use serde::Serialize;

/// A PyTorch adapter for Burn modules, used during deserialization.
///
/// Not all Burn modules correspond directly to PyTorch modules, so we need
/// to adapt the PyTorch data to the Burn module layout. We implement
/// adapters only for the modules that differ.
pub struct PyTorchAdapter<PS: PrecisionSettings, B: Backend> {
    _precision_settings: std::marker::PhantomData<(PS, B)>,
}

impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS, B> {
    fn adapt_linear(data: NestedValue) -> NestedValue {
        // Get the current module in the form of a map.
        let mut map = data.as_map().expect("Failed to get map from NestedValue");

        // Get/remove the weight parameter.
        let weight = map
            .remove("weight")
            .expect("Failed to find 'weight' key in map");

        // Convert the weight parameter to a tensor (use the default device, since it's a quick operation).
        let weight: Param<Tensor<B, 2>> = weight
            .try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
            .expect("Failed to deserialize weight");

        // Transpose the weight tensor.
        let weight_transposed = Param::from(weight.val().transpose());

        // Insert the transposed weight tensor back into the map.
        map.insert(
            "weight".to_owned(),
            serialize::<PS, _, 2>(weight_transposed),
        );

        // Return the modified map.
        NestedValue::Map(map)
    }

    fn adapt_group_norm(data: NestedValue) -> NestedValue {
        rename_weight_bias(data)
    }

    fn adapt_batch_norm(data: NestedValue) -> NestedValue {
        rename_weight_bias(data)
    }

    fn adapt_layer_norm(data: NestedValue) -> NestedValue {
        rename_weight_bias(data)
    }
}

/// Helper function to serialize a param tensor.
fn serialize<PS, B, const D: usize>(val: Param<Tensor<B, D>>) -> NestedValue
where
    B: Backend,
    PS: PrecisionSettings,
{
    let serializer = Serializer::new();

    val.into_item::<PS>()
        .serialize(serializer)
        .expect("Failed to serialize the item")
}

/// Helper function to rename the weight and bias parameters to gamma and beta.
///
/// This is needed because PyTorch and Burn use different names for the
/// normalizer parameters: Burn uses gamma and beta, while PyTorch uses weight and bias.
fn rename_weight_bias(data: NestedValue) -> NestedValue {
    // Get the current module in the form of a map.
    let mut map = data.as_map().expect("Failed to get map from NestedValue");

    // Rename the weight parameter to gamma.
    let weight = map
        .remove("weight")
        .expect("Failed to find 'weight' key in map");

    map.insert("gamma".to_owned(), weight);

    // Rename the bias parameter to beta.
    let bias = map
        .remove("bias")
        .expect("Failed to find 'bias' key in map");

    map.insert("beta".to_owned(), bias);

    // Return the modified map.
    NestedValue::Map(map)
}

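Note: a minimal sketch (not part of this diff, using only `burn`'s public `Tensor` API) of the layout mismatch that `adapt_linear` compensates for. PyTorch's `nn.Linear` stores its weight as `[d_output, d_input]`, while Burn's `Linear` expects `[d_input, d_output]`, hence the transpose above.

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Convert a PyTorch-layout linear weight to Burn's layout:
/// [d_output, d_input] -> [d_input, d_output].
fn pytorch_linear_weight_to_burn<B: Backend>(weight: Tensor<B, 2>) -> Tensor<B, 2> {
    weight.transpose()
}
```
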
@ -0,0 +1,21 @@
use burn::record::{serde::error, RecorderError};

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("Serde error: {0}")]
    Serde(#[from] error::Error),

    #[error("Candle pickle error: {0}")]
    CandlePickle(#[from] candle_core::Error),

    // Add other kinds of errors as needed.
    #[error("other error: {0}")]
    Other(String),
}

// Implement the From trait to convert Error into RecorderError.
impl From<Error> for RecorderError {
    fn from(error: Error) -> Self {
        RecorderError::DeserializeError(error.to_string())
    }
}

@ -0,0 +1,6 @@
mod adapter;
mod error;
mod reader;
mod recorder;

pub use recorder::{LoadArgs, PyTorchFileRecorder};

@ -0,0 +1,129 @@
use core::ops::Deref;
use std::collections::HashMap;
use std::path::Path;

use super::{adapter::PyTorchAdapter, error::Error};

use burn::{
    module::ParamId,
    record::{ParamSerde, PrecisionSettings},
    tensor::{DataSerialize, Element, ElementConversion},
};
use burn::{
    record::serde::{
        data::{remap, unflatten, NestedValue, Serializable},
        de::Deserializer,
        error,
        ser::Serializer,
    },
    tensor::backend::Backend,
};

use candle_core::{pickle, WithDType};
use half::{bf16, f16};
use regex::Regex;
use serde::{de::DeserializeOwned, Serialize};

/// Deserializes a PyTorch file.
///
/// # Arguments
///
/// * `path` - The path of the file to read.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
pub fn from_file<PS, D, B>(path: &Path, key_remap: Vec<(Regex, String)>) -> Result<D, Error>
where
    D: DeserializeOwned,
    PS: PrecisionSettings,
    B: Backend,
{
    // Read the pickle file and collect the Candle tensors into a map.
    let tensors: HashMap<String, CandleTensor> = pickle::read_all(path)?
        .into_iter()
        .map(|(key, tensor)| (key, CandleTensor(tensor)))
        .collect();

    // Remap the keys (replace the keys in the map with the new keys).
    let tensors = remap(tensors, key_remap);

    // Convert the map of Candle tensors to a nested value data structure.
    let nested_value = unflatten::<PS, _>(tensors)?;

    // Create a deserializer with the PyTorch adapter and the nested value.
    let deserializer = Deserializer::<PyTorchAdapter<PS, B>>::new(nested_value, true);

    // Deserialize the nested value into a record type.
    let value = D::deserialize(deserializer)?;
    Ok(value)
}

/// Serializes a Candle tensor.
///
/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `DataSerialize` struct.
///
/// Values are serialized as `FloatElem` or `IntElem` depending on the precision settings.
impl Serializable for CandleTensor {
    fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, error::Error>
    where
        PS: PrecisionSettings,
    {
        let shape = self.shape().clone().into_dims();
        let flatten = CandleTensor(self.flatten_all().expect("Failed to flatten the tensor"));
        let param_id = ParamId::new().into_string();

        match self.dtype() {
            candle_core::DType::U8 => {
                serialize_data::<u8, PS::IntElem>(flatten, shape, param_id, serializer)
            }
            candle_core::DType::U32 => {
                serialize_data::<u32, PS::IntElem>(flatten, shape, param_id, serializer)
            }
            candle_core::DType::I64 => {
                serialize_data::<i64, PS::IntElem>(flatten, shape, param_id, serializer)
            }
            candle_core::DType::BF16 => {
                serialize_data::<bf16, PS::FloatElem>(flatten, shape, param_id, serializer)
            }
            candle_core::DType::F16 => {
                serialize_data::<f16, PS::FloatElem>(flatten, shape, param_id, serializer)
            }
            candle_core::DType::F32 => {
                serialize_data::<f32, PS::FloatElem>(flatten, shape, param_id, serializer)
            }
            candle_core::DType::F64 => {
                serialize_data::<f64, PS::FloatElem>(flatten, shape, param_id, serializer)
            }
        }
    }
}

/// Helper function to serialize Candle tensor data.
fn serialize_data<T, E>(
    tensor: CandleTensor,
    shape: Vec<usize>,
    param_id: String,
    serializer: Serializer,
) -> Result<NestedValue, error::Error>
where
    E: Element + Serialize,
    T: WithDType + ElementConversion,
{
    let data: Vec<E> = tensor
        .to_vec1::<T>()
        .map_err(|err| error::Error::Other(format!("Candle to vec1 error: {err}")))?
        .into_iter()
        .map(ElementConversion::elem)
        .collect();

    ParamSerde::new(param_id, DataSerialize::new(data, shape)).serialize(serializer)
}

/// Newtype struct wrapping a Candle tensor, so that we can implement the `Serializable` trait for it.
struct CandleTensor(candle_core::Tensor);

impl Deref for CandleTensor {
    type Target = candle_core::Tensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

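Note: since the reader is built directly on `candle_core::pickle::read_all`, a small helper along these lines (a debugging aid, not part of this diff) can dump the tensor names and shapes stored in a `.pt` file, which helps when writing key remappings:

```rust
use std::path::Path;

use candle_core::pickle;

/// Print every tensor name and shape stored in a PyTorch pickle file.
fn list_pt_keys(path: &Path) -> candle_core::Result<()> {
    for (name, tensor) in pickle::read_all(path)? {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}
```
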
@ -0,0 +1,136 @@
use core::marker::PhantomData;
use std::path::PathBuf;

use burn::{
    record::{PrecisionSettings, Record, Recorder, RecorderError},
    tensor::backend::Backend,
};

use regex::Regex;
use serde::{de::DeserializeOwned, Serialize};

use super::reader::from_file;

/// A recorder that loads PyTorch files (`.pt`) into Burn modules.
///
/// `LoadArgs` can be used to specify the file path and to remap keys.
/// See [LoadArgs](struct.LoadArgs.html) for more information.
#[derive(new, Debug, Default, Clone)]
pub struct PyTorchFileRecorder<PS: PrecisionSettings> {
    _settings: PhantomData<PS>,
}

impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS> {
    type Settings = PS;
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = LoadArgs;

    fn save_item<I: Serialize>(
        &self,
        _item: I,
        _file: Self::RecordArgs,
    ) -> Result<(), RecorderError> {
        unimplemented!("save_item not implemented for PyTorchFileRecorder")
    }

    fn load_item<I: DeserializeOwned>(&self, _file: Self::LoadArgs) -> Result<I, RecorderError> {
        unimplemented!("load_item not implemented for PyTorchFileRecorder")
    }

    fn load<R: Record<B>>(
        &self,
        args: Self::LoadArgs,
        device: &B::Device,
    ) -> Result<R, RecorderError> {
        let item = from_file::<PS, R::Item<Self::Settings>, B>(&args.file, args.key_remap)?;
        Ok(R::from_item(item, device))
    }
}

/// Arguments for loading a PyTorch file.
///
/// # Fields
///
/// * `file` - The path to the file to load.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
///   See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
///   for more information.
///
/// # Notes
///
/// Use [Netron](https://github.com/lutzroeder/netron) to inspect the keys of the PyTorch file (.pt extension).
///
/// # Examples
///
/// ```text
/// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
/// use burn::record::FullPrecisionSettings;
/// use burn::record::Recorder;
///
/// let device = Default::default();
/// let args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
///     .with_key_remap("conv\\.(.*)", "$1"); // Remove the "conv" prefix, e.g. "conv.conv1" -> "conv1"
///
/// let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
///     .load(args, &device)
///     .expect("Should decode state successfully");
/// ```
#[derive(Debug, Clone)]
pub struct LoadArgs {
    /// The path to the file to load.
    pub file: PathBuf,

    /// A list of key remappings.
    pub key_remap: Vec<(Regex, String)>,
}

impl LoadArgs {
    /// Create a new `LoadArgs` instance.
    ///
    /// # Arguments
    ///
    /// * `file` - The path to the file to load.
    pub fn new(file: PathBuf) -> Self {
        Self {
            file,
            key_remap: Vec::new(),
        }
    }

    /// Set a key remapping.
    ///
    /// # Arguments
    ///
    /// * `pattern` - The regex pattern to be replaced.
    /// * `replacement` - The pattern to replace with.
    ///
    /// See [Regex](https://docs.rs/regex/1.5.4/regex/#syntax) for the pattern syntax and
    /// [Replacement](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) for the
    /// replacement syntax.
    pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
        let regex = Regex::new(&format!("^{}$", pattern)).unwrap();

        self.key_remap.push((regex, replacement.into()));
        self
    }
}

impl From<PathBuf> for LoadArgs {
    fn from(val: PathBuf) -> Self {
        LoadArgs::new(val)
    }
}

impl From<String> for LoadArgs {
    fn from(val: String) -> Self {
        LoadArgs::new(val.into())
    }
}

impl From<&str> for LoadArgs {
    fn from(val: &str) -> Self {
        LoadArgs::new(val.into())
    }
}

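Note: for illustration, a hedged sketch of chaining several remappings before loading; the file path and key patterns are hypothetical, and the record type is left generic:

```rust
use burn::record::{FullPrecisionSettings, Record, Recorder};
use burn::tensor::backend::Backend;
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

/// Load any record type from a PyTorch file, rewriting keys on the way in.
fn load_with_remaps<B: Backend, R: Record<B>>(device: &B::Device) -> R {
    // Hypothetical weight file and key layout, for illustration only.
    let args = LoadArgs::new("weights/net.pt".into())
        // Drop a "model." prefix: "model.fc1.weight" -> "fc1.weight".
        .with_key_remap("model\\.(.*)", "$1")
        // Rename one key outright.
        .with_key_remap("head\\.weight", "fc2.weight");

    PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(args, device)
        .expect("Should decode state successfully")
}
```
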
@ -53,6 +53,9 @@ candle = ["burn-core/candle"]
# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]

# Records
record-item-custom-serde = ["burn-core/record-item-custom-serde"]

[dependencies]

# ** Please make sure all dependencies support no_std when std is disabled **

@ -0,0 +1,24 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
edition = "2021"
license = "MIT OR Apache-2.0"
name = "pytorch-import"
publish = false
version = "0.12.0"

[dependencies]
burn = { path = "../../burn", features = [
    "ndarray",
    "dataset",
    "sqlite-bundled",
] }

model = { path = "./model" }

[build-dependencies]
model = { path = "./model" }
burn = { path = "../../burn", features = ["ndarray"] }
burn-import = { path = "../../burn-import", features = [
    "pytorch",
], default-features = false }

@ -0,0 +1,29 @@
# Import PyTorch Weights

This crate provides a simple example of importing PyTorch-generated weights into Burn.

The `.pt` file is converted into a Burn-consumable file (MessagePack format) using `burn-import`.
The conversion is done in the `build.rs` file (sketched below).

The model lives in a separate sub-crate because `build.rs` needs it for the conversion, and
`build.rs` cannot import modules from its own crate.

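The conversion itself is a load followed by a re-save; here is a condensed, hedged sketch of what `build.rs` does (the output path is illustrative; the full version appears later in this diff):

```rust
use burn::backend::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

type B = NdArray<f32>;

fn convert(device: &<B as burn::tensor::backend::Backend>::Device) {
    // Load the PyTorch weights into a Burn record...
    let record: model::ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("pytorch/mnist.pt".into(), device)
        .expect("Failed to decode state");

    // ...and re-save it in Burn's named MessagePack format.
    NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .record(record, "out/model/mnist".into())
        .expect("Failed to save model record");
}
```
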
## Usage

```bash
cargo run -- 15
```

Output:

```bash
    Finished dev [unoptimized + debuginfo] target(s) in 0.13s
     Running `burn/target/debug/pytorch-import 15`

Image index: 15
Success!
Predicted: 5
Actual: 5
See the image online, click the link below:
https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/15/image/image.jpg
```

@ -0,0 +1,41 @@
/// This build script does the following:
/// 1. Loads PyTorch weights into a model record.
/// 2. Saves the model record to a file using the `NamedMpkFileRecorder`.
use std::path::Path;

use burn::backend::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

// Backend used for loading the model record.
type B = NdArray<f32>;

fn main() {
    if cfg!(target_os = "windows") {
        println!(
            "{}",
            "cargo:warning=The crate is not supported on Windows because of ".to_owned()
                + "Candle's pt bug on Windows "
                + "(see https://github.com/huggingface/candle/issues/1454)."
        );
        std::process::exit(0);
    }

    let device = Default::default();

    // Load the PyTorch weights into a model record.
    let record: model::ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("pytorch/mnist.pt".into(), &device)
        .expect("Failed to decode state");

    // Save the model record to a file.
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();

    // Save into the OUT_DIR directory so that the model can be loaded from there by the model crate.
    let out_dir = std::env::var("OUT_DIR").unwrap();
    let file_path = Path::new(&out_dir).join("model/mnist");

    recorder
        .record(record, file_path)
        .expect("Failed to save model record");
}

@ -0,0 +1,11 @@
[package]
name = "model"
version = "0.1.0"
edition = "2021"

[dependencies]
burn = { path = "../../../burn" }
burn-import = { path = "../../../burn-import", features = [
    "pytorch",
], default-features = false }

@ -0,0 +1,81 @@
use std::env;
use std::path::Path;

use burn::nn::conv::Conv2d;
use burn::nn::conv::Conv2dConfig;
use burn::nn::BatchNorm;
use burn::nn::BatchNormConfig;
use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::record::FullPrecisionSettings;
use burn::record::NamedMpkFileRecorder;
use burn::record::Recorder;
use burn::tensor::activation::log_softmax;
use burn::tensor::activation::relu;
use burn::{
    module::Module,
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    conv3: Conv2d<B>,
    norm1: BatchNorm<B, 2>,
    fc1: Linear<B>,
    fc2: Linear<B>,
    norm2: BatchNorm<B, 0>,
    phantom: core::marker::PhantomData<B>,
}

impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        let out_dir = env::var_os("OUT_DIR").unwrap();
        let file_path = Path::new(&out_dir).join("model/mnist");

        let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
            .load(file_path, &B::Device::default())
            .expect("Failed to decode state");

        Self::new_with(record)
    }
}

impl<B: Backend> Model<B> {
    pub fn new_with(record: ModelRecord<B>) -> Self {
        let conv1 = Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1);
        let conv2 = Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2);
        let conv3 = Conv2dConfig::new([16, 24], [3, 3]).init_with(record.conv3);
        let norm1 = BatchNormConfig::new(24).init_with(record.norm1);
        let fc1 = LinearConfig::new(11616, 32).init_with(record.fc1);
        let fc2 = LinearConfig::new(32, 10).init_with(record.fc2);
        let norm2 = BatchNormConfig::new(10).init_with(record.norm2);
        Self {
            conv1,
            conv2,
            conv3,
            norm1,
            fc1,
            fc2,
            norm2,
            phantom: core::marker::PhantomData,
        }
    }

    pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> {
        let conv1_out1 = self.conv1.forward(input1);
        let relu1_out1 = relu(conv1_out1);
        let conv2_out1 = self.conv2.forward(relu1_out1);
        let relu2_out1 = relu(conv2_out1);
        let conv3_out1 = self.conv3.forward(relu2_out1);
        let relu3_out1 = relu(conv3_out1);
        let norm1_out1 = self.norm1.forward(relu3_out1);
        let flatten1_out1 = norm1_out1.flatten(1, 3);
        let fc1_out1 = self.fc1.forward(flatten1_out1);
        let relu4_out1 = relu(fc1_out1);
        let fc2_out1 = self.fc2.forward(relu4_out1);
        let norm2_out1 = self.norm2.forward(fc2_out1);
        log_softmax(norm2_out1, 1)
    }
}

Binary file not shown.
@ -0,0 +1,163 @@
#!/usr/bin/env python3

# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py
# under the following license: BSD-3-Clause license

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 24, 3)
        self.norm1 = nn.BatchNorm2d(24)
        self.dropout1 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(24 * 22 * 22, 32)
        self.fc2 = nn.Linear(32, 10)
        self.norm2 = nn.BatchNorm1d(10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.norm1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.norm2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=8, metavar='N',
                        help='number of epochs to train (default: 8)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    parser.add_argument('--export-onnx', action='store_true', default=False,
                        help='For Saving the current Model in ONNX format')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('/tmp/mnist-data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('/tmp/mnist-data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist.pt")

    if args.export_onnx:
        dummy_input = torch.randn(1, 1, 28, 28, device=device)
        torch.onnx.export(model, dummy_input, "mnist.onnx",
                          verbose=True, opset_version=16)


if __name__ == '__main__':
    main()

@ -0,0 +1,74 @@
use std::env::args;
use std::path::Path;

use burn::backend::ndarray::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::tensor::Tensor;

use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::Dataset;

use model::Model;

const IMAGE_INX: usize = 42; // <- Change this to test a different image

// Build output directory that contains the converted model weight file.
const OUT_DIR: &str = concat!(env!("OUT_DIR"), "/model/mnist");

fn main() {
    // Get the image index argument (first) from the command line.
    let image_index = if let Some(image_index) = args().nth(1) {
        println!("Image index: {}", image_index);
        image_index
            .parse::<usize>()
            .expect("Failed to parse image index")
    } else {
        println!("No image index provided; using default image index: {IMAGE_INX}");
        IMAGE_INX
    };

    assert!(image_index < 10000, "Image index must be less than 10000");

    type Backend = NdArray<f32>;
    let device = Default::default();

    // Load the model record from the PyTorch file converted by the build script.
    let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .load(Path::new(OUT_DIR).into(), &device)
        .expect("Failed to decode state");

    // Create a new model and load the state.
    let model: Model<Backend> = Model::new_with(record);

    // Load the MNIST dataset and get an item.
    let dataset = MNISTDataset::test();
    let item = dataset.get(image_index).unwrap();

    // Create a tensor from the image data.
    let image_data = item.image.iter().copied().flatten().collect::<Vec<f32>>();
    let mut input: Tensor<Backend, 4> =
        Tensor::from_floats(image_data.as_slice(), &device).reshape([1, 1, 28, 28]);

    // Normalize the input: scale to [0, 1], then standardize with the MNIST
    // mean (0.1307) and std (0.3081) used during training.
    input = ((input / 255) - 0.1307) / 0.3081;

    // Run the model on the input.
    let output = model.forward(input);

    // Get the index of the maximum value.
    let arg_max = output.argmax(1).into_scalar() as usize;

    // Check that the index matches the label.
    assert!(arg_max == item.label);

    println!("Success!");
    println!("Predicted: {}", arg_max);
    println!("Actual: {}", item.label);

    // Print the image URL if the image index is less than 100 (the online dataset only has 100 images).
    if image_index < 100 {
        println!("See the image online, click the link below:");
        println!("https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{image_index}/image/image.jpg");
    }
}

@ -246,9 +246,18 @@ fn no_std_checks() {

// Test burn-core with tch and wgpu backend
fn burn_core_std() {
    // Run cargo test --features test-tch
    group!("Test: burn-core (tch)");
    cargo_test(["-p", "burn-core", "--features", "test-tch"].into());
    // Run cargo test --features test-tch, record-item-custom-serde
    group!("Test: burn-core (tch) and record-item-custom-serde");
    cargo_test(
        [
            "-p",
            "burn-core",
            "--features",
            "test-tch",
            "record-item-custom-serde",
        ]
        .into(),
    );
    endgroup!();

    // Run cargo test --features test-wgpu