State serialization/deserialization overhaul (#247)

This commit is contained in:
Nathaniel Simard 2023-03-23 11:02:46 -04:00 committed by GitHub
parent 00625d1527
commit 6f43d983f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 361 additions and 209 deletions

View File

@ -42,6 +42,9 @@ ndarray = {version = "0.15.6", default-features = false}
num-traits = {version = "0.2.15", default-features = false, features = ["libm"]}# libm is for no_std
rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# std_rng is for no_std
rand_distr = {version = "0.4.3", default-features = false}
uuid = {version = "1.3.0", default-features = false}
serde = {version = "1.0.155", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed
serde_json = {version = "1.0.94", default-features = false}
uuid = {version = "1.3.0", default-features = false}
rmp-serde = {version = "1.1.1"}
bincode = {version = "2.0.0-rc", features=["alloc", "serde"], default-features = false}

View File

@ -13,7 +13,6 @@ version = "0.6.0"
[features]
default = ["std"]
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
std = [
"burn-autodiff",
"burn-common/std",
@ -24,10 +23,19 @@ std = [
"flate2",
"log",
"rand/std",
"serde_json/std",
"serde/std",
"serde_json/std",
"bincode/std",
"half/std",
"half/serde", # TODO: set default when https://github.com/starkat99/half-rs/issues/84 is fixed
]
# Serialization formats
msgpack = ["rmp-serde"] # Assumes std
test-tch = [] # To use tch during testing, default uses ndarray.
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
[dependencies]
@ -52,7 +60,11 @@ hashbrown = {workspace = true, features = ["serde"]}# no_std compatible
# Serialize Deserialize
flate2 = {workspace = true, optional = true}
serde = {workspace = true, features = ["derive"]}
serde_json = {workspace = true, features = ["alloc"]}#Default enables std
rmp-serde = {workspace = true, optional = true}
bincode = {workspace = true}
half = {workspace = true}
[dev-dependencies]
burn-dataset = {path = "../burn-dataset", version = "0.6.0", features = [

View File

@ -1,21 +1,17 @@
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use super::ParamId;
use crate::tensor::{DataSerialize, Element};
use serde::{Deserialize, Serialize};
#[cfg(feature = "std")]
use std::{collections::HashMap, fs::File, path::Path};
#[cfg(feature = "std")]
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[derive(Debug, PartialEq, Eq, Clone, Default, Serialize, Deserialize)]
pub struct StateNamed<E> {
@ -35,6 +31,23 @@ pub enum StateError {
FileNotFound(String),
}
/// All supported formats for saving and loading a state.
///
/// # Notes
///
/// The default file format is compressed bincode for the smallest file size possible.
/// For `no_std` environments, you should use [StateFormat::Bin](StateFormat::Bin) since compression isn't supported.
/// However, the bincode format alone is smaller than compressed `json` or `msgpack`.
#[derive(Default, Clone)]
pub enum StateFormat {
/// Gzip-compressed bincode, written with the `bin.gz` extension (the default).
#[default]
BinGz,
/// Uncompressed bincode, written with the `bin` extension.
Bin,
/// Gzip-compressed JSON, written with the `json.gz` extension.
JsonGz,
/// Gzip-compressed `rmp-serde` output, written with the `mpk.gz` extension.
/// Only available when the `msgpack` feature is enabled.
#[cfg(feature = "msgpack")]
MpkGz,
}
impl core::fmt::Display for StateError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let mut message = "State error => ".to_string();
@ -52,10 +65,6 @@ impl core::fmt::Display for StateError {
}
}
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
#[cfg(feature = "std")]
impl std::error::Error for StateError {}
impl<E: Element> StateNamed<E> {
pub fn new() -> Self {
Self {
@ -113,41 +122,167 @@ impl<E: Element> State<E> {
}
}
#[cfg(feature = "std")]
impl<E: Element> State<E>
where
E: serde::de::DeserializeOwned,
E: serde::Serialize,
{
pub fn save(self, file: &str) -> std::io::Result<()> {
let path = Path::new(file);
if path.exists() {
log::info!("File exists, replacing");
std::fs::remove_file(path).unwrap();
/// Encode the state into an in-memory bincode buffer.
///
/// # Panics
///
/// Panics if bincode fails to encode the state.
pub fn to_bin(&self) -> Result<Vec<u8>, StateError> {
    let bytes = bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap();
    Ok(bytes)
}
/// Decode a state from a bincode buffer, e.g. one produced by `to_bin`.
///
/// # Panics
///
/// Panics if `data` cannot be decoded.
pub fn from_bin(data: &[u8]) -> Result<Self, StateError> {
    Ok(bincode::serde::decode_borrowed_from_slice(data, Self::bin_config()).unwrap())
}
/// Bincode configuration used for every bincode encode/decode of a state.
///
/// Keeping it in one place ensures that writers and readers always agree on the
/// encoding; it currently returns bincode's `standard()` configuration.
fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}
}
#[cfg(feature = "std")]
mod std_enabled {
use super::*;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use std::{fs::File, path::Path};
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
impl std::error::Error for StateError {}
// Opens `<file>.<ext>` for reading.
//
// Evaluates to `Result<File, StateError>`; any error from `File::open` is
// reported as `StateError::FileNotFound` carrying the debug-formatted I/O error.
macro_rules! str2reader {
    ($file:expr, $ext:expr) => {{
        let file_name = format!("{}.{}", $file, $ext);
        File::open(Path::new(&file_name))
            .map_err(|err| StateError::FileNotFound(format!("{err:?}")))
    }};
}
// Creates `<file>.<ext>` for writing, replacing any existing file.
//
// Evaluates to `std::io::Result<File>`. A failure to remove the previous file is
// returned as an error instead of panicking, so callers can handle it with `?`.
macro_rules! str2writer {
    (
        $file:expr,
        $ext:expr
    ) => {{
        let path_ref = &format!("{}.{}", $file, $ext);
        let path = Path::new(path_ref);

        let removed = if path.exists() {
            log::info!("File exists, replacing");
            std::fs::remove_file(path)
        } else {
            Ok(())
        };

        removed.and_then(|()| File::create(path))
    }};
}
impl<E: Element> State<E>
where
E: serde::de::DeserializeOwned,
E: serde::Serialize,
{
/// Write the state to disk, encoded with the requested [StateFormat](StateFormat).
///
/// # Notes
///
/// `file` is the path without extension: the extension matching the format is
/// appended automatically.
pub fn save(self, file: &str, format: &StateFormat) -> std::io::Result<()> {
    match *format {
        StateFormat::BinGz => self.save_bingz(file),
        StateFormat::Bin => self.save_bin(file),
        StateFormat::JsonGz => self.save_jsongz(file),
        #[cfg(feature = "msgpack")]
        StateFormat::MpkGz => self.save_mpkgz(file),
    }
}
let writer = File::create(path)?;
let writer = GzEncoder::new(writer, Compression::default());
serde_json::to_writer(writer, &self).unwrap();
/// Read a state from disk, decoding it with the requested [StateFormat](StateFormat).
///
/// # Notes
///
/// `file` is the path without extension: the extension matching the format is
/// appended automatically.
pub fn load(file: &str, format: &StateFormat) -> Result<Self, StateError> {
    match *format {
        StateFormat::BinGz => Self::load_bingz(file),
        StateFormat::Bin => Self::load_bin(file),
        StateFormat::JsonGz => Self::load_jsongz(file),
        #[cfg(feature = "msgpack")]
        StateFormat::MpkGz => Self::load_mpkgz(file),
    }
}
Ok(())
}
fn save_jsongz(self, file: &str) -> std::io::Result<()> {
let writer = str2writer!(file, "json.gz")?;
let writer = GzEncoder::new(writer, Compression::default());
serde_json::to_writer(writer, &self).unwrap();
pub fn load(file: &str) -> Result<Self, StateError> {
let path = Path::new(file);
let reader =
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))?;
let reader = GzDecoder::new(reader);
let state = serde_json::from_reader(reader).unwrap();
Ok(())
}
Ok(state)
}
fn load_jsongz(file: &str) -> Result<Self, StateError> {
let reader = str2reader!(file, "json.gz")?;
let reader = GzDecoder::new(reader);
let state = serde_json::from_reader(reader).unwrap();
pub fn load_binary(data: &[u8]) -> Result<Self, StateError> {
let reader = GzDecoder::new(data);
let state = serde_json::from_reader(reader).unwrap();
Ok(state)
}
Ok(state)
/// Write the state to `<file>.mpk.gz` as gzip-compressed `rmp-serde` output.
///
/// # Errors
///
/// Returns an error if the file cannot be created, if encoding fails, or if the
/// gzip stream cannot be completed.
#[cfg(feature = "msgpack")]
fn save_mpkgz(self, file: &str) -> std::io::Result<()> {
    let writer = str2writer!(file, "mpk.gz")?;
    let mut writer = GzEncoder::new(writer, Compression::default());

    rmp_serde::encode::write(&mut writer, &self)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, format!("{err:?}")))?;

    // Finish explicitly: when the encoder is merely dropped, errors raised while
    // writing the end of the gzip stream are silently discarded.
    writer.finish()?;

    Ok(())
}
/// Read a state from `<file>.mpk.gz` (gzip-compressed `rmp-serde` output).
///
/// # Panics
///
/// Panics if the file content cannot be decoded.
#[cfg(feature = "msgpack")]
fn load_mpkgz(file: &str) -> Result<Self, StateError> {
    let decoder = GzDecoder::new(str2reader!(file, "mpk.gz")?);

    Ok(rmp_serde::decode::from_read(decoder).unwrap())
}
/// Write the state to `<file>.bin.gz` as gzip-compressed bincode.
///
/// # Errors
///
/// Returns an error if the file cannot be created, if encoding fails, or if the
/// gzip stream cannot be completed.
fn save_bingz(self, file: &str) -> std::io::Result<()> {
    let config = Self::bin_config();
    let writer = str2writer!(file, "bin.gz")?;
    let mut writer = GzEncoder::new(writer, Compression::default());

    bincode::serde::encode_into_std_write(&self, &mut writer, config)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, format!("{err:?}")))?;

    // Finish explicitly: when the encoder is merely dropped, errors raised while
    // writing the end of the gzip stream are silently discarded.
    writer.finish()?;

    Ok(())
}
/// Read a state from `<file>.bin.gz` (gzip-compressed bincode).
///
/// # Panics
///
/// Panics if the file content cannot be decoded.
fn load_bingz(file: &str) -> Result<Self, StateError> {
    let mut decoder = GzDecoder::new(str2reader!(file, "bin.gz")?);
    let config = Self::bin_config();

    Ok(bincode::serde::decode_from_std_read(&mut decoder, config).unwrap())
}
/// Write the state to `<file>.bin` as uncompressed bincode.
///
/// The state is fully encoded in memory before the file is created.
///
/// # Errors
///
/// Returns an error if encoding fails, if the file cannot be created, or if the
/// bytes cannot be written.
fn save_bin(self, file: &str) -> std::io::Result<()> {
    let buf = bincode::serde::encode_to_vec(self, Self::bin_config())
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, format!("{err:?}")))?;

    let mut writer = str2writer!(file, "bin")?;

    // Propagate write failures (e.g. a full disk) instead of panicking.
    std::io::Write::write_all(&mut writer, &buf)
}
/// Read a state from `<file>.bin` (uncompressed bincode).
///
/// The whole file is read into memory before being decoded.
///
/// # Panics
///
/// Panics if the file cannot be read to the end or if its content cannot be decoded.
fn load_bin(file: &str) -> Result<Self, StateError> {
    let mut source = str2reader!(file, "bin")?;

    let mut bytes = Vec::new();
    std::io::Read::read_to_end(&mut source, &mut bytes).unwrap();

    let config = Self::bin_config();
    let state = bincode::serde::decode_borrowed_from_slice(&bytes, config).unwrap();

    Ok(state)
}
}
}
@ -170,21 +305,6 @@ mod tests {
assert_eq!(state, state_from);
}
#[cfg(feature = "std")]
#[test]
fn test_can_save_and_load_from_file() {
let model_before = create_model();
let state_before = model_before.state();
state_before.clone().save("/tmp/test.json").unwrap();
let model_after = create_model()
.load(&State::load("/tmp/test.json").unwrap())
.unwrap();
let state_after = model_after.state();
assert_eq!(state_before, state_after);
}
#[test]
fn test_parameter_ids_are_loaded() {
let model_1 = create_model();
@ -200,33 +320,80 @@ mod tests {
assert_eq!(params_before_1, params_after_2);
}
#[cfg(feature = "std")]
#[test]
fn test_load_binary() {
fn test_from_to_binary() {
let model_1 = create_model();
let mut model_2 = create_model();
let model_2 = create_model();
let params_before_1 = list_param_ids(&model_1);
let params_before_2 = list_param_ids(&model_2);
// Write to binary.
let state = model_1.state();
let mut binary = Vec::new();
let writer = GzEncoder::new(&mut binary, Compression::default());
serde_json::to_writer(writer, &state).unwrap();
// Load.
model_2 = model_2.load(&State::load_binary(&binary).unwrap()).unwrap();
let params_after_2 = list_param_ids(&model_2);
// To & From Bytes
let bytes = model_1.state().to_bin().unwrap();
let model_2 = model_2.load(&State::from_bin(&bytes).unwrap()).unwrap();
// Verify.
let params_after_2 = list_param_ids(&model_2);
assert_ne!(params_before_1, params_before_2);
assert_eq!(params_before_1, params_after_2);
}
fn create_model() -> nn::Linear<TestBackend> {
pub fn create_model() -> nn::Linear<TestBackend> {
nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig::new(32, 32).with_bias(true))
}
}
/// Round-trip tests for the file based `save`/`load` API (requires `std`).
#[cfg(all(test, feature = "std"))]
mod tests_save_load {
    use super::tests::create_model;
    use super::*;
    use crate::module::Module;

    /// Builds an extension-less file path inside the platform temporary directory,
    /// so the tests do not depend on a hard-coded `/tmp`.
    fn temp_file(name: &str) -> String {
        std::env::temp_dir()
            .join(name)
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn test_can_save_and_load_from_file_jsongz_format() {
        test_can_save_and_load_from_file(StateFormat::JsonGz)
    }

    #[test]
    fn test_can_save_and_load_from_file_bin_format() {
        test_can_save_and_load_from_file(StateFormat::Bin)
    }

    #[test]
    fn test_can_save_and_load_from_file_bingz_format() {
        test_can_save_and_load_from_file(StateFormat::BinGz)
    }

    #[cfg(feature = "msgpack")]
    #[test]
    fn test_can_save_and_load_from_file_mpkgz_format() {
        test_can_save_and_load_from_file(StateFormat::MpkGz)
    }

    /// A file written with `StateFormat::Bin` must be decodable with `State::from_bin`.
    #[test]
    fn test_from_bin_on_disk() {
        let model = create_model();
        let file = temp_file("model_compare");

        model.state().save(&file, &StateFormat::Bin).unwrap();

        // `save` appends the extension of the chosen format.
        let bytes = std::fs::read(format!("{file}.bin")).unwrap();
        let state = State::from_bin(&bytes).unwrap();

        assert_eq!(state, model.state());
    }

    /// Saves the state of a fresh model, loads it into another one and compares both states.
    fn test_can_save_and_load_from_file(format: StateFormat) {
        let file = temp_file("test_state");
        let model_before = create_model();
        let state_before = model_before.state();

        state_before.clone().save(&file, &format).unwrap();

        let model_after = create_model()
            .load(&State::load(&file, &format).unwrap())
            .unwrap();

        let state_after = model_after.state();
        assert_eq!(state_before, state_after);
    }
}

View File

@ -3,8 +3,9 @@ use alloc::vec::Vec;
use core::cmp::Ordering;
use core::ops::Range;
use crate::element::FloatNdArrayElement;
// Current crate
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
use crate::element::FloatNdArrayElement;
use crate::{tensor::NdArrayTensor, NdArrayBackend};
use crate::{NdArrayDevice, SEED};
@ -15,10 +16,9 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
// External crates
use libm::{cos, erf, sin, tanh};
#[cfg(not(feature = "std"))]
use num_traits::Float; // Can't compare two floats with no_std.
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
#[cfg(not(feature = "std"))]
use num_traits::Float;
impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {

View File

@ -17,7 +17,7 @@ doc = ["tch/doc-only"]
burn-tensor = {path = "../burn-tensor", version = "0.6.0"}
libc = "0.2.0"
half = {workspace = true}
half = {workspace = true, features = ["std"]}
rand = {workspace = true, features = ["std"]}
[target.'cfg(not(target_arch = "aarch64"))'.dependencies]

View File

@ -19,6 +19,8 @@ experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"]
std = [
"rand/std",
"half/std",
"half/serde", # TODO: set default when https://github.com/starkat99/half-rs/issues/84 is fixed
]
[dependencies]

View File

@ -7,55 +7,55 @@ use libm::{pow, round};
use rand::{distributions::Standard, Rng, RngCore};
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct DataSerialize<P> {
pub value: Vec<P>,
pub struct DataSerialize<E> {
pub value: Vec<E>,
pub shape: Vec<usize>,
}
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct Data<P, const D: usize> {
pub value: Vec<P>,
pub struct Data<E, const D: usize> {
pub value: Vec<E>,
pub shape: Shape<D>,
}
#[derive(Clone, Copy)]
pub enum Distribution<P> {
pub enum Distribution<E> {
Standard,
Bernoulli(f64),
Uniform(P, P),
Uniform(E, E),
Normal(f64, f64),
}
#[derive(new)]
pub struct DistributionSampler<'a, P, R>
pub struct DistributionSampler<'a, E, R>
where
Standard: rand::distributions::Distribution<P>,
P: rand::distributions::uniform::SampleUniform,
Standard: rand::distributions::Distribution<E>,
E: rand::distributions::uniform::SampleUniform,
R: RngCore,
{
kind: DistributionSamplerKind<P>,
kind: DistributionSamplerKind<E>,
rng: &'a mut R,
}
pub enum DistributionSamplerKind<P>
pub enum DistributionSamplerKind<E>
where
Standard: rand::distributions::Distribution<P>,
P: rand::distributions::uniform::SampleUniform,
Standard: rand::distributions::Distribution<E>,
E: rand::distributions::uniform::SampleUniform,
{
Standard(rand::distributions::Standard),
Uniform(rand::distributions::Uniform<P>),
Uniform(rand::distributions::Uniform<E>),
Bernoulli(rand::distributions::Bernoulli),
Normal(rand_distr::Normal<f64>),
}
impl<'a, P, R> DistributionSampler<'a, P, R>
impl<'a, E, R> DistributionSampler<'a, E, R>
where
Standard: rand::distributions::Distribution<P>,
P: rand::distributions::uniform::SampleUniform,
P: Element,
Standard: rand::distributions::Distribution<E>,
E: rand::distributions::uniform::SampleUniform,
E: Element,
R: RngCore,
{
pub fn sample(&mut self) -> P {
pub fn sample(&mut self) -> E {
match &self.kind {
DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
@ -71,12 +71,12 @@ where
}
}
impl<P> Distribution<P>
impl<E> Distribution<E>
where
Standard: rand::distributions::Distribution<P>,
P: rand::distributions::uniform::SampleUniform,
Standard: rand::distributions::Distribution<E>,
E: rand::distributions::uniform::SampleUniform,
{
pub fn sampler<R: RngCore>(self, rng: &'_ mut R) -> DistributionSampler<'_, P, R> {
pub fn sampler<R: RngCore>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> {
let kind = match self {
Distribution::Standard => {
DistributionSamplerKind::Standard(rand::distributions::Standard {})
@ -96,23 +96,25 @@ where
}
}
impl<P> Distribution<P>
impl<E> Distribution<E>
where
P: Element,
E: Element,
{
pub fn convert<E: Element>(self) -> Distribution<E> {
pub fn convert<EOther: Element>(self) -> Distribution<EOther> {
match self {
Distribution::Standard => Distribution::Standard,
Distribution::Uniform(a, b) => Distribution::Uniform(E::from_elem(a), E::from_elem(b)),
Distribution::Uniform(a, b) => {
Distribution::Uniform(EOther::from_elem(a), EOther::from_elem(b))
}
Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob),
Distribution::Normal(mean, std) => Distribution::Normal(mean, std),
}
}
}
impl<const D: usize, P: Element> Data<P, D> {
pub fn convert<E: Element>(self) -> Data<E, D> {
let value: Vec<E> = self.value.into_iter().map(|a| a.elem()).collect();
impl<const D: usize, E: Element> Data<E, D> {
pub fn convert<EOther: Element>(self) -> Data<EOther, D> {
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
Data {
value,
@ -121,9 +123,9 @@ impl<const D: usize, P: Element> Data<P, D> {
}
}
impl<P: Element> DataSerialize<P> {
pub fn convert<E: Element>(self) -> DataSerialize<E> {
let value: Vec<E> = self.value.into_iter().map(|a| a.elem()).collect();
impl<E: Element> DataSerialize<E> {
pub fn convert<EOther: Element>(self) -> DataSerialize<EOther> {
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
DataSerialize {
value,
@ -142,23 +144,23 @@ impl<const D: usize> Data<bool, D> {
}
}
}
impl<P: Element, const D: usize> Data<P, D> {
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution<P>, rng: &mut R) -> Self {
impl<E: Element, const D: usize> Data<E, D> {
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution<E>, rng: &mut R) -> Self {
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);
for _ in 0..num_elements {
data.push(P::random(distribution, rng));
data.push(E::random(distribution, rng));
}
Data::new(data, shape)
}
}
impl<P: core::fmt::Debug, const D: usize> Data<P, D>
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
P: Element,
E: Element,
{
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<P, D> {
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<E, D> {
let shape = shape.into();
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);
@ -169,16 +171,16 @@ where
Data::new(data, shape)
}
pub fn zeros_(shape: Shape<D>, _kind: P) -> Data<P, D> {
pub fn zeros_(shape: Shape<D>, _kind: E) -> Data<E, D> {
Self::zeros(shape)
}
}
impl<P: core::fmt::Debug, const D: usize> Data<P, D>
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
P: Element,
E: Element,
{
pub fn ones(shape: Shape<D>) -> Data<P, D> {
pub fn ones(shape: Shape<D>) -> Data<E, D> {
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);
@ -188,13 +190,13 @@ where
Data::new(data, shape)
}
pub fn ones_(shape: Shape<D>, _kind: P) -> Data<P, D> {
pub fn ones_(shape: Shape<D>, _kind: E) -> Data<E, D> {
Self::ones(shape)
}
}
impl<P: core::fmt::Debug + Copy, const D: usize> Data<P, D> {
pub fn serialize(&self) -> DataSerialize<P> {
impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
pub fn serialize(&self) -> DataSerialize<E> {
DataSerialize {
value: self.value.clone(),
shape: self.shape.dims.to_vec(),
@ -202,7 +204,7 @@ impl<P: core::fmt::Debug + Copy, const D: usize> Data<P, D> {
}
}
impl<P: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<P, D> {
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E, D> {
pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
assert_eq!(self.shape, other.shape);
@ -246,24 +248,24 @@ impl<const D: usize> Data<usize, D> {
}
}
impl<P: Clone, const D: usize> From<&DataSerialize<P>> for Data<P, D> {
fn from(data: &DataSerialize<P>) -> Self {
impl<E: Clone, const D: usize> From<&DataSerialize<E>> for Data<E, D> {
fn from(data: &DataSerialize<E>) -> Self {
let mut dims = [0; D];
dims[..D].copy_from_slice(&data.shape[..D]);
Data::new(data.value.clone(), Shape::new(dims))
}
}
impl<P, const D: usize> From<DataSerialize<P>> for Data<P, D> {
fn from(data: DataSerialize<P>) -> Self {
impl<E, const D: usize> From<DataSerialize<E>> for Data<E, D> {
fn from(data: DataSerialize<E>) -> Self {
let mut dims = [0; D];
dims[..D].copy_from_slice(&data.shape[..D]);
Data::new(data.value, Shape::new(dims))
}
}
impl<P: core::fmt::Debug + Copy, const A: usize> From<[P; A]> for Data<P, 1> {
fn from(elems: [P; A]) -> Self {
impl<E: core::fmt::Debug + Copy, const A: usize> From<[E; A]> for Data<E, 1> {
fn from(elems: [E; A]) -> Self {
let mut data = Vec::with_capacity(2 * A);
for elem in elems.into_iter() {
data.push(elem);
@ -273,8 +275,8 @@ impl<P: core::fmt::Debug + Copy, const A: usize> From<[P; A]> for Data<P, 1> {
}
}
impl<P: core::fmt::Debug + Copy> From<&[P]> for Data<P, 1> {
fn from(elems: &[P]) -> Self {
impl<E: core::fmt::Debug + Copy> From<&[E]> for Data<E, 1> {
fn from(elems: &[E]) -> Self {
let mut data = Vec::with_capacity(elems.len());
for elem in elems.iter() {
data.push(*elem);
@ -284,8 +286,8 @@ impl<P: core::fmt::Debug + Copy> From<&[P]> for Data<P, 1> {
}
}
impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[P; B]; A]> for Data<P, 2> {
fn from(elems: [[P; B]; A]) -> Self {
impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[E; B]; A]> for Data<E, 2> {
fn from(elems: [[E; B]; A]) -> Self {
let mut data = Vec::with_capacity(A * B);
for elem in elems.into_iter().take(A) {
for elem in elem.into_iter().take(B) {
@ -297,10 +299,10 @@ impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[P; B]; A
}
}
impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
From<[[[P; C]; B]; A]> for Data<P, 3>
impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
From<[[[E; C]; B]; A]> for Data<E, 3>
{
fn from(elems: [[[P; C]; B]; A]) -> Self {
fn from(elems: [[[E; C]; B]; A]) -> Self {
let mut data = Vec::with_capacity(A * B * C);
for elem in elems.into_iter().take(A) {
@ -316,14 +318,14 @@ impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
}
impl<
P: core::fmt::Debug + Copy,
E: core::fmt::Debug + Copy,
const A: usize,
const B: usize,
const C: usize,
const D: usize,
> From<[[[[P; D]; C]; B]; A]> for Data<P, 4>
> From<[[[[E; D]; C]; B]; A]> for Data<E, 4>
{
fn from(elems: [[[[P; D]; C]; B]; A]) -> Self {
fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
let mut data = Vec::with_capacity(A * B * C * D);
for elem in elems.into_iter().take(A) {
@ -340,11 +342,12 @@ impl<
}
}
impl<P: core::fmt::Debug, const D: usize> core::fmt::Display for Data<P, D> {
impl<E: core::fmt::Debug, const D: usize> core::fmt::Display for Data<E, D> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(format!("{:?}", &self.value).as_str())
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,27 +1,29 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::module::State;
use burn_core::module::{State, StateFormat};
use burn_core::tensor::Element;
pub struct FileCheckpointer<P> {
directory: String,
name: String,
num_keep: usize,
format: StateFormat,
_precision: P,
}
impl<P: Element> FileCheckpointer<P> {
pub fn new(directory: &str, name: &str, num_keep: usize) -> Self {
pub fn new(directory: &str, name: &str, num_keep: usize, format: StateFormat) -> Self {
std::fs::create_dir_all(directory).ok();
Self {
directory: directory.to_string(),
name: name.to_string(),
num_keep,
format,
_precision: P::default(),
}
}
fn path_for_epoch(&self, epoch: usize) -> String {
format!("{}/{}-{}.json.gz", self.directory, self.name, epoch)
format!("{}/{}-{}", self.directory, self.name, epoch)
}
}
@ -36,7 +38,7 @@ where
state
.convert::<P>()
.save(&file_path)
.save(&file_path, &self.format)
.map_err(CheckpointerError::IOError)?;
if self.num_keep > epoch {
@ -57,7 +59,8 @@ where
let file_path = self.path_for_epoch(epoch);
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
let state = State::<P>::load(&file_path).map_err(CheckpointerError::StateError)?;
let state =
State::<P>::load(&file_path, &self.format).map_err(CheckpointerError::StateError)?;
Ok(state.convert())
}

View File

@ -6,7 +6,7 @@ use crate::metric::dashboard::cli::CLIDashboardRenderer;
use crate::metric::dashboard::Dashboard;
use crate::metric::{Adaptor, Metric, Numeric};
use crate::AsyncTrainerCallback;
use burn_core::module::ADModule;
use burn_core::module::{ADModule, StateFormat};
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::ADBackend;
use burn_core::tensor::Element;
@ -143,16 +143,19 @@ where
/// Register file checkpointers for both the model and the optimizer.
///
/// Both checkpointers write into `<directory>/checkpoint`, under the names `model`
/// and `optim`, and receive the same `num_keep` and state `format`. `P` is the
/// element type handed to the checkpointers.
pub fn with_file_checkpointer<P: Element + serde::de::DeserializeOwned + serde::Serialize>(
    mut self,
    num_keep: usize,
    format: StateFormat,
) -> Self {
    let directory = format!("{}/checkpoint", self.directory);

    let model = FileCheckpointer::<P>::new(&directory, "model", num_keep, format.clone());
    self.checkpointer_model = Some(Arc::new(model));

    let optimizer = FileCheckpointer::<P>::new(&directory, "optim", num_keep, format);
    self.checkpointer_optimizer = Some(Arc::new(optimizer));

    self
}

View File

@ -19,6 +19,8 @@ std = [
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
train = ["std", "burn-train"] # Training requires std
msgpack = ["burn-core/msgpack"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **

View File

@ -15,10 +15,6 @@ default = []
[dependencies]
burn = {path = "../../burn", default-features = false}
burn-ndarray = {path = "../../burn-ndarray", default-features = false}
# 2.0 supports no_std and serde
bincode = {version = "2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", default-features = false, features = ["alloc", "serde"]}
serde = {workspace = true}
wasm-bindgen = "0.2.84"
@ -26,6 +22,5 @@ wasm-bindgen = "0.2.84"
burn-dataset = {path = "../../burn-dataset"}
[build-dependencies]
bincode = {version = "2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", default-features = false, features = ["alloc", "serde"]}
burn = {path = "../../burn"}
serde = {workspace = true}

View File

@ -27,14 +27,9 @@ makes it possible to build and run the model with the `wasm32-unknown-unknown` t
special system library, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) on how to
include burn dependencies without `std`).
For this demo, we use trained parameters (`model-4.json.gz`) and model (`model.rs`) from the
For this demo, we use trained parameters (`model.bin`) and model (`model.rs`) from the
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).
During the build time `model-4.json.gz` is converted to
[`bincode`](https://github.com/bincode-org/bincode) (for compactness) and included as part of the
final wasm output. The MNIST model is initialized with trained weights from memory during the
runtime.
The inference API for JavaScript is exposed with the help of
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.
@ -75,8 +70,6 @@ byte file is the model's parameters. The rest of 356,744 bytes contain all the c
There are several planned enhancements in place:
- [#201](https://github.com/burn-rs/burn/issues/201) - Saving model's params in binary format. This
will simplify the inference code.
- [#202](https://github.com/burn-rs/burn/issues/202) - Saving model's params in half-precision and
loading back in full. This can be half the size of the wasm file.
- [#243](https://github.com/burn-rs/burn/issues/243) - New WebGPU backend would allow computation

View File

@ -1,28 +0,0 @@
use std::fs;
use burn::module::State;
use bincode::config;
const GENERATED_FILE_NAME: &str = "mnist_model_state.bincode";
const MODEL_STATE_FILE_NAME: &str = "model-4.json.gz";
/// This build step converts the JSON-serialized model state to bincode serialization
/// in order to make the file small and efficient for bundling the binary into wasm code.
///
/// The result is written to `GENERATED_FILE_NAME` inside cargo's `OUT_DIR`.
///
/// This will be removed once https://github.com/burn-rs/burn/issues/201 is resolved.
fn main() {
    println!("cargo:rerun-if-changed={MODEL_STATE_FILE_NAME}");

    let out_dir = std::env::var("OUT_DIR").expect("No build target path set");
    let path = std::path::Path::new(&out_dir).join(GENERATED_FILE_NAME);

    let state: State<f32> =
        State::load(MODEL_STATE_FILE_NAME).expect("Model JSON file could not be loaded");

    let serialized = bincode::serde::encode_to_vec(state, config::standard())
        .expect("Encoding state into bincode failed");

    fs::write(path, serialized).expect("Write failed");
}

Binary file not shown.

View File

@ -4,27 +4,14 @@ use burn::module::Module;
use burn::module::State;
use burn_ndarray::NdArrayBackend;
use bincode::{
config::{self, Configuration},
serde::decode_from_slice,
};
pub type Backend = NdArrayBackend<f32>;
const BINCODE_CONF: Configuration = config::standard();
// Bundled bincode serialized model state object
// see https://github.com/bincode-org/bincode and https://doc.rust-lang.org/std/macro.include_bytes.html
static STATE_ENCODED: &[u8] =
include_bytes!(concat!(env!("OUT_DIR"), "/mnist_model_state.bincode"));
static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
/// Builds and loads trained parameters into the model.
pub fn build_and_load_model() -> Model<Backend> {
let model: Model<Backend> = Model::new();
// TODO: fix forward slash to make the paths work in windows
let (state, _len): (State<f32>, usize) =
decode_from_slice(STATE_ENCODED, BINCODE_CONF).expect("Failed to decode state");
let state: State<f32> = State::from_bin(STATE_ENCODED).expect("Failed to decode state");
model
.load(&state)

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use crate::data::MNISTBatcher;
use crate::model::Model;
use burn::module::{Module, StateFormat};
use burn::optim::decay::WeightDecayConfig;
use burn::optim::{Adam, AdamConfig};
use burn::{
@ -64,14 +65,21 @@ pub fn run<B: ADBackend>(device: B::Device) {
.metric_valid_plot(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<f32>(2)
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
.devices(vec![device])
.num_epochs(config.num_epochs)
.build(model, optim);
let _model_trained = learner.fit(dataloader_train, dataloader_test);
let model_trained = learner.fit(dataloader_train, dataloader_test);
config
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
.unwrap();
// We save a bin version of the model so it can be loaded in a no_std environment.
model_trained
.state()
.convert::<f32>()
.save(&format!("{ARTIFACT_DIR}/model"), &StateFormat::Bin)
.expect("Failed to save trained model");
}

View File

@ -5,7 +5,7 @@ use crate::{
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,
module::Module,
module::{Module, StateFormat},
nn::transformer::TransformerEncoderConfig,
optim::{Sgd, SgdConfig},
tensor::backend::ADBackend,
@ -78,7 +78,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
.metric_valid(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<f32>(2)
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
.devices(vec![device])
.num_epochs(config.num_epochs)
.build(model, optim);
@ -86,9 +86,10 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let model_trained = learner.fit(dataloader_train, dataloader_test);
config.save(&format!("{artifact_dir}/config.json")).unwrap();
model_trained
.state()
.convert::<f32>()
.save(&format!("{artifact_dir}/model.json.gz"))
.convert::<burn::tensor::f16>()
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
.unwrap();
}

View File

@ -2,7 +2,6 @@ use crate::{
data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer},
model::{TextGenerationModel, TextGenerationModelConfig},
};
use burn::data::dataset::transform::SamplerDataset;
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
@ -15,6 +14,7 @@ use burn::{
LearnerBuilder,
},
};
use burn::{data::dataset::transform::SamplerDataset, module::StateFormat};
use std::sync::Arc;
#[derive(Config)]
@ -75,7 +75,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
.metric_valid(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<f32>(2)
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
.devices(vec![device])
.grads_accumulation(16)
.num_epochs(config.num_epochs)
@ -84,9 +84,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
let model_trained = learner.fit(dataloader_train, dataloader_test);
config.save(&format!("{artifact_dir}/config.json")).unwrap();
model_trained
.state()
.convert::<f32>()
.save(&format!("{artifact_dir}/model.json.gz"))
.convert::<burn::tensor::f16>()
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
.unwrap();
}