Feat/record (#262)

This commit is contained in:
Nathaniel Simard 2023-04-02 10:09:29 -04:00 committed by GitHub
parent 4e9e6d2706
commit 73f6d1916b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 645 additions and 339 deletions

View File

@ -13,6 +13,7 @@ pub mod optim;
pub mod module;
pub mod nn;
pub mod record;
pub mod tensor;
extern crate alloc;

View File

@ -1,7 +1,6 @@
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use super::ParamId;
@ -28,24 +27,7 @@ pub enum State<E> {
#[derive(Debug)]
pub enum StateError {
InvalidFormat(String),
FileNotFound(String),
}
/// All supported format.
///
/// # Notes
///
/// The default file format is compressed bincode for the smallest file size possible.
/// For `no_std` environments, you should use (StateFormat::Bin)[StateFormat::Bin] since compression isn't supported.
/// However, the bincode format alone is smaller than compressed `json` or `msgpack`.
#[derive(Default, Clone)]
pub enum StateFormat {
#[default]
BinGz,
Bin,
JsonGz,
#[cfg(feature = "msgpack")]
MpkGz,
Unknown(String),
}
impl core::fmt::Display for StateError {
@ -56,8 +38,8 @@ impl core::fmt::Display for StateError {
Self::InvalidFormat(err) => {
message += format!("Invalid format: {err}").as_str();
}
Self::FileNotFound(err) => {
message += format!("File not found: {err}").as_str();
Self::Unknown(err) => {
message += format!("Unknown error: {err}").as_str();
}
};
@ -122,170 +104,6 @@ impl<E: Element> State<E> {
}
}
impl<E: Element> State<E>
where
E: serde::de::DeserializeOwned,
E: serde::Serialize,
{
pub fn to_bin(&self) -> Result<Vec<u8>, StateError> {
Ok(bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap())
}
pub fn from_bin(data: &[u8]) -> Result<Self, StateError> {
let state = bincode::serde::decode_borrowed_from_slice(data, Self::bin_config()).unwrap();
Ok(state)
}
fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}
}
#[cfg(feature = "std")]
mod std_enabled {
use super::*;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use std::{fs::File, path::Path};
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
impl std::error::Error for StateError {}
macro_rules! str2reader {
(
$file:expr,
$ext:expr
) => {{
let path_ref = &format!("{}.{}", $file, $ext);
let path = Path::new(path_ref);
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))
}};
}
macro_rules! str2writer {
(
$file:expr,
$ext:expr
) => {{
let path_ref = &format!("{}.{}", $file, $ext);
let path = Path::new(path_ref);
if path.exists() {
log::info!("File exists, replacing");
std::fs::remove_file(path).unwrap();
}
File::create(path)
}};
}
impl<E: Element> State<E>
where
E: serde::de::DeserializeOwned,
E: serde::Serialize,
{
/// Save the state to the provided file path using the given [StateFormat](StateFormat).
///
/// # Notes
///
/// The file extension will be added automatically depending on the state format.
pub fn save(self, file: &str, format: &StateFormat) -> std::io::Result<()> {
match format {
StateFormat::BinGz => self.save_bingz(file),
StateFormat::Bin => self.save_bin(file),
StateFormat::JsonGz => self.save_jsongz(file),
#[cfg(feature = "msgpack")]
StateFormat::MpkGz => self.save_mpkgz(file),
}
}
/// Load the state from the provided file path using the given [StateFormat](StateFormat).
///
/// # Notes
///
/// The file extension will be added automatically depending on the state format.
pub fn load(file: &str, format: &StateFormat) -> Result<Self, StateError> {
match format {
StateFormat::BinGz => Self::load_bingz(file),
StateFormat::Bin => Self::load_bin(file),
StateFormat::JsonGz => Self::load_jsongz(file),
#[cfg(feature = "msgpack")]
StateFormat::MpkGz => Self::load_mpkgz(file),
}
}
fn save_jsongz(self, file: &str) -> std::io::Result<()> {
let writer = str2writer!(file, "json.gz")?;
let writer = GzEncoder::new(writer, Compression::default());
serde_json::to_writer(writer, &self).unwrap();
Ok(())
}
fn load_jsongz(file: &str) -> Result<Self, StateError> {
let reader = str2reader!(file, "json.gz")?;
let reader = GzDecoder::new(reader);
let state = serde_json::from_reader(reader).unwrap();
Ok(state)
}
#[cfg(feature = "msgpack")]
fn save_mpkgz(self, file: &str) -> std::io::Result<()> {
let writer = str2writer!(file, "mpk.gz")?;
let mut writer = GzEncoder::new(writer, Compression::default());
rmp_serde::encode::write(&mut writer, &self).unwrap();
Ok(())
}
#[cfg(feature = "msgpack")]
fn load_mpkgz(file: &str) -> Result<Self, StateError> {
let reader = str2reader!(file, "mpk.gz")?;
let reader = GzDecoder::new(reader);
let state = rmp_serde::decode::from_read(reader).unwrap();
Ok(state)
}
fn save_bingz(self, file: &str) -> std::io::Result<()> {
let config = Self::bin_config();
let writer = str2writer!(file, "bin.gz")?;
let mut writer = GzEncoder::new(writer, Compression::default());
bincode::serde::encode_into_std_write(&self, &mut writer, config).unwrap();
Ok(())
}
fn load_bingz(file: &str) -> Result<Self, StateError> {
let reader = str2reader!(file, "bin.gz")?;
let mut reader = GzDecoder::new(reader);
let state =
bincode::serde::decode_from_std_read(&mut reader, Self::bin_config()).unwrap();
Ok(state)
}
fn save_bin(self, file: &str) -> std::io::Result<()> {
let buf = bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap();
let mut writer = str2writer!(file, "bin")?;
std::io::Write::write_all(&mut writer, &buf).unwrap();
Ok(())
}
fn load_bin(file: &str) -> Result<Self, StateError> {
let mut reader = str2reader!(file, "bin")?;
let mut buf = Vec::new();
std::io::Read::read_to_end(&mut reader, &mut buf).unwrap();
let state =
bincode::serde::decode_borrowed_from_slice(&buf, Self::bin_config()).unwrap();
Ok(state)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -313,6 +131,7 @@ mod tests {
let params_before_2 = list_param_ids(&model_2);
let state = model_1.state();
model_2 = model_2.load(&state).unwrap();
let params_after_2 = list_param_ids(&model_2);
@ -320,80 +139,7 @@ mod tests {
assert_eq!(params_before_1, params_after_2);
}
#[test]
fn test_from_to_binary() {
let model_1 = create_model();
let model_2 = create_model();
let params_before_1 = list_param_ids(&model_1);
let params_before_2 = list_param_ids(&model_2);
// To & From Bytes
let bytes = model_1.state().to_bin().unwrap();
let model_2 = model_2.load(&State::from_bin(&bytes).unwrap()).unwrap();
// Verify.
let params_after_2 = list_param_ids(&model_2);
assert_ne!(params_before_1, params_before_2);
assert_eq!(params_before_1, params_after_2);
}
pub fn create_model() -> nn::Linear<TestBackend> {
nn::LinearConfig::new(32, 32).with_bias(true).init()
}
}
#[cfg(all(test, feature = "std"))]
mod tests_save_load {
use super::tests::create_model;
use super::*;
use crate::module::Module;
static FILE_PATH: &str = "/tmp/test_state";
#[test]
fn test_can_save_and_load_from_file_jsongz_format() {
test_can_save_and_load_from_file(StateFormat::JsonGz)
}
#[test]
fn test_can_save_and_load_from_file_bin_format() {
test_can_save_and_load_from_file(StateFormat::Bin)
}
#[test]
fn test_can_save_and_load_from_file_bingz_format() {
test_can_save_and_load_from_file(StateFormat::BinGz)
}
#[cfg(feature = "msgpack")]
#[test]
fn test_can_save_and_load_from_file_mpkgz_format() {
test_can_save_and_load_from_file(StateFormat::MpkGz)
}
#[test]
fn test_from_bin_on_disk() {
let model = create_model();
model
.state()
.save("/tmp/model_compare", &StateFormat::Bin)
.unwrap();
let bytes = std::fs::read("/tmp/model_compare.bin").unwrap();
let state = State::from_bin(&bytes).unwrap();
assert_eq!(state, model.state());
}
fn test_can_save_and_load_from_file(format: StateFormat) {
let model_before = create_model();
let state_before = model_before.state();
state_before.clone().save(FILE_PATH, &format).unwrap();
let model_after = create_model()
.load(&State::load(FILE_PATH, &format).unwrap())
.unwrap();
let state_after = model_after.state();
assert_eq!(state_before, state_after);
}
}

View File

@ -0,0 +1,149 @@
use super::{RecordSettings, Recorder, RecorderError};
use crate::alloc::string::ToString;
use alloc::format;
use alloc::string::String;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// Trait to define a family of types which can be recorded using any [settings](RecordSettings).
pub trait Record: Send + Sync {
type Item<S: RecordSettings>;
/// Convert the current record into the corresponding item that follows the given [settings](RecordSettings).
fn into_item<S: RecordSettings>(self) -> Self::Item<S>;
/// Convert the given item into a record.
fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self;
/// Record using the given [settings](RecordSettings).
fn record<S>(self, args: RecordArgs<S>) -> RecordOutputResult<S>
where
Self: Sized,
S: RecordSettings,
Self::Item<S>: Serialize + DeserializeOwned,
{
let metadata = BurnMetadata::new(
core::any::type_name::<S::FloatElem>().to_string(),
core::any::type_name::<S::IntElem>().to_string(),
core::any::type_name::<S::Recorder>().to_string(),
env!("CARGO_PKG_VERSION").to_string(),
format!("{:?}", S::default()),
);
let item = self.into_item::<S>();
let record = BurnRecord::new(item, metadata);
RecorderType::<S>::record(record, args)
}
/// Load the record using the given [settings](RecordSettings).
fn load<S>(args: LoadArgs<S>) -> Result<Self, RecorderError>
where
Self: Sized,
S: RecordSettings,
Self::Item<S>: Serialize + DeserializeOwned,
{
let record: BurnRecord<Self::Item<S>> =
RecorderType::<S>::load(args.clone()).map_err(|err| {
let message = match err {
RecorderError::FileNotFound(_) => return err,
RecorderError::Unknown(message) => message,
};
let record = RecorderType::<S>::load::<BurnRecordNoItem>(args);
let message = match record {
Ok(record) => format!(
"Unable to load record with settings {:?}, found metadata {:?}, err: {:?}",
S::default(),
record.metadata,
message
),
Err(_err) => message,
};
RecorderError::Unknown(message)
})?;
Ok(Self::from_item(record.item))
}
}
/// Record arguments for the given settings.
pub type RecordArgs<S> = <<S as RecordSettings>::Recorder as Recorder>::RecordArgs;
/// Record loading arguments for the given settings.
pub type LoadArgs<S> = <<S as RecordSettings>::Recorder as Recorder>::LoadArgs;
/// Record output result for the given settings.
pub type RecordOutputResult<S> =
Result<<<S as RecordSettings>::Recorder as Recorder>::RecordOutput, RecorderError>;
/// Recorder for the given settings.
pub type RecorderType<S> = <S as RecordSettings>::Recorder;
#[derive(new, Debug, Serialize, Deserialize)]
struct BurnMetadata {
float: String,
int: String,
format: String,
version: String,
settings: String,
}
#[derive(new, Serialize, Deserialize)]
struct BurnRecord<I> {
item: I,
metadata: BurnMetadata,
}
#[derive(new, Serialize, Deserialize)]
struct BurnRecordNoItem {
metadata: BurnMetadata,
}
#[cfg(all(test, feature = "std"))]
mod tests {
static FILE_PATH: &str = "/tmp/burn_test_record";
use core::marker::PhantomData;
use super::*;
use crate::record::FileJsonGzRecorder;
use burn_tensor::{Element, ElementConversion};
#[test]
#[should_panic]
fn err_when_invalid_object() {
#[derive(Debug, Default)]
pub struct TestSettings<F> {
float: PhantomData<F>,
}
impl<F: Element + Serialize + DeserializeOwned> RecordSettings for TestSettings<F> {
type FloatElem = F;
type IntElem = i32;
type Recorder = FileJsonGzRecorder;
}
#[derive(new, Serialize, Deserialize)]
struct Item<S: RecordSettings> {
value: S::FloatElem,
}
impl<D: RecordSettings> Record for Item<D> {
type Item<S: RecordSettings> = Item<S>;
fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
Item {
value: self.value.elem(),
}
}
fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
Item {
value: item.value.elem(),
}
}
}
let item = Item::<TestSettings<f32>>::new(16.elem());
// Serialize in f32.
item.record::<TestSettings<f32>>(FILE_PATH.into()).unwrap();
// Can't deserialize u8 into f32.
Item::<TestSettings<f32>>::load::<TestSettings<u8>>(FILE_PATH.into()).unwrap();
}
}

View File

@ -0,0 +1,217 @@
use super::{bin_config, Recorder, RecorderError};
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use serde::{de::DeserializeOwned, Serialize};
use std::{fs::File, path::PathBuf};
/// Recorder trait specialized to save and load data to and from files.
pub trait FileRecorder:
Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
}
/// File recorder using the [bincode format](bincode).
pub struct FileBinRecorder;
/// File recorder using the [bincode format](bincode) compressed with gzip.
pub struct FileBinGzRecorder;
/// File recorder using the json format compressed with gzip.
pub struct FileJsonGzRecorder;
#[cfg(feature = "msgpack")]
/// File recorder using the [message pack](rmp_serde) format compressed with gzip.
pub struct FileMpkGzRecorder;
impl FileRecorder for FileBinGzRecorder {}
impl FileRecorder for FileBinRecorder {}
impl FileRecorder for FileJsonGzRecorder {}
#[cfg(feature = "msgpack")]
impl FileRecorder for FileMpkGzRecorder {}
macro_rules! str2reader {
(
$file:expr,
$ext:expr
) => {{
$file.set_extension($ext);
let path = $file.as_path();
File::open(path).map_err(|err| match err.kind() {
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
_ => RecorderError::Unknown(err.to_string()),
})
}};
}
macro_rules! str2writer {
(
$file:expr,
$ext:expr
) => {{
$file.set_extension($ext);
let path = $file.as_path();
if path.exists() {
log::info!("File exists, replacing");
std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
}
File::create(path).map_err(|err| match err.kind() {
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
_ => RecorderError::Unknown(err.to_string()),
})
}};
}
impl Recorder for FileBinGzRecorder {
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = PathBuf;
fn record<Obj: Serialize + DeserializeOwned>(
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let config = bin_config();
let writer = str2writer!(file, "bin.gz")?;
let mut writer = GzEncoder::new(writer, Compression::default());
bincode::serde::encode_into_std_write(&obj, &mut writer, config)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "bin.gz")?;
let mut reader = GzDecoder::new(reader);
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}
impl Recorder for FileBinRecorder {
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = PathBuf;
fn record<Obj: Serialize + DeserializeOwned>(
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let config = bin_config();
let mut writer = str2writer!(file, "bin")?;
bincode::serde::encode_into_std_write(&obj, &mut writer, config)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let mut reader = str2reader!(file, "bin")?;
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}
impl Recorder for FileJsonGzRecorder {
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = PathBuf;
fn record<Obj: Serialize + DeserializeOwned>(
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let writer = str2writer!(file, "json.gz")?;
let writer = GzEncoder::new(writer, Compression::default());
serde_json::to_writer(writer, &obj)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "json.gz")?;
let reader = GzDecoder::new(reader);
let state = serde_json::from_reader(reader)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}
#[cfg(feature = "msgpack")]
impl Recorder for FileMpkGzRecorder {
type SaveArgs = PathBuf;
type SaveOutput = ();
type LoadArgs = PathBuf;
fn save<Obj: Serialize + DeserializeOwned>(
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let writer = str2writer!(file, "mpk.gz")?;
let mut writer = GzEncoder::new(writer, Compression::default());
rmp_serde::encode::write(&mut writer, &obj)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "mpk.gz")?;
let reader = GzDecoder::new(reader);
let state = rmp_serde::decode::from_read(reader)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{module::Module, nn, TestBackend};
static FILE_PATH: &str = "/tmp/burn_test_file_recorder";
#[test]
fn test_can_save_and_load_jsongz_format() {
test_can_save_and_load::<FileJsonGzRecorder>()
}
#[test]
fn test_can_save_and_load_bin_format() {
test_can_save_and_load::<FileBinRecorder>()
}
#[test]
fn test_can_save_and_load_bingz_format() {
test_can_save_and_load::<FileBinGzRecorder>()
}
#[cfg(feature = "msgpack")]
#[test]
fn test_can_save_and_load_mpkgz_format() {
test_can_save_and_load::<FileMpkGzRecorder>()
}
fn test_can_save_and_load<Recorder: FileRecorder>() {
let model_before = create_model();
let state_before = model_before.state();
Recorder::record(state_before.clone(), FILE_PATH.into()).unwrap();
let model_after = create_model()
.load(&Recorder::load(FILE_PATH.into()).unwrap())
.unwrap();
let state_after = model_after.state();
assert_eq!(state_before, state_after);
}
pub fn create_model() -> nn::Linear<TestBackend> {
nn::LinearConfig::new(32, 32).with_bias(true).init()
}
}

View File

@ -0,0 +1,65 @@
use super::{bin_config, Recorder, RecorderError};
use alloc::vec::Vec;
use serde::{de::DeserializeOwned, Serialize};
/// Recorder trait specialized to save and load data to and from bytes.
///
/// # Notes
///
/// This is especialy useful in no_std environment where weights are stored directly in
/// compiled binaries.
pub trait InMemoryRecorder:
Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
{
}
/// In memory recorder using the [bincode format](bincode).
pub struct InMemoryBinRecorder;
impl InMemoryRecorder for InMemoryBinRecorder {}
impl Recorder for InMemoryBinRecorder {
type RecordArgs = ();
type RecordOutput = Vec<u8>;
type LoadArgs = Vec<u8>;
fn record<Obj: Serialize + DeserializeOwned>(
obj: Obj,
_args: Self::RecordArgs,
) -> Result<Vec<u8>, RecorderError> {
Ok(bincode::serde::encode_to_vec(obj, bin_config()).unwrap())
}
fn load<Obj: Serialize + DeserializeOwned>(args: Self::LoadArgs) -> Result<Obj, RecorderError> {
let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap();
Ok(state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{module::Module, nn, TestBackend};
#[test]
fn test_can_save_and_load_bin_format() {
test_can_save_and_load::<InMemoryBinRecorder>()
}
fn test_can_save_and_load<Recorder: InMemoryRecorder>() {
let model_before = create_model();
let state_before = model_before.state();
let bytes = Recorder::record(state_before.clone(), ()).unwrap();
let model_after = create_model()
.load(&Recorder::load(bytes).unwrap())
.unwrap();
let state_after = model_after.state();
assert_eq!(state_before, state_after);
}
pub fn create_model() -> nn::Linear<TestBackend> {
nn::LinearConfig::new(32, 32).with_bias(true).init()
}
}

View File

@ -0,0 +1,16 @@
mod state;
mod base;
mod memory;
mod recorder;
mod settings;
pub use base::*;
pub use memory::*;
pub use recorder::*;
pub use settings::*;
#[cfg(feature = "std")]
mod file;
#[cfg(feature = "std")]
pub use file::*;

View File

@ -0,0 +1,42 @@
use alloc::format;
use alloc::string::String;
use serde::{de::DeserializeOwned, Serialize};
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
pub trait Recorder: Send + Sync {
/// Arguments used to record objects.
type RecordArgs: Clone;
/// Record output type.
type RecordOutput;
/// Arguments used to load recorded objects.
type LoadArgs: Clone;
fn record<Item: Serialize + DeserializeOwned>(
item: Item,
args: Self::RecordArgs,
) -> Result<Self::RecordOutput, RecorderError>;
/// Load an item from the given arguments.
fn load<Item: Serialize + DeserializeOwned>(
args: Self::LoadArgs,
) -> Result<Item, RecorderError>;
}
#[derive(Debug)]
pub enum RecorderError {
FileNotFound(String),
Unknown(String),
}
impl core::fmt::Display for RecorderError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(format!("{self:?}").as_str())
}
}
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
#[cfg(feature = "std")]
impl std::error::Error for RecorderError {}
pub(crate) fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}

View File

@ -0,0 +1,44 @@
use super::Recorder;
use burn_tensor::Element;
use serde::{de::DeserializeOwned, Serialize};
pub trait RecordSettings: Send + Sync + core::fmt::Debug + core::default::Default {
type FloatElem: Element + Serialize + DeserializeOwned;
type IntElem: Element + Serialize + DeserializeOwned;
type Recorder: Recorder;
}
/// Default record settings.
#[derive(Debug, Default)]
pub struct DefaultRecordSettings;
/// Training settings compatible with no-std inference.
#[derive(Debug, Default)]
pub struct NoStdTrainingRecordSettings;
/// Inference settings compatible with no-std.
#[derive(Debug, Default)]
pub struct NoStdInferenceRecordSettings;
impl RecordSettings for DefaultRecordSettings {
#[cfg(feature = "std")]
type FloatElem = half::f16;
#[cfg(not(feature = "std"))]
type FloatElem = f32;
type IntElem = i16;
#[cfg(feature = "std")]
type Recorder = crate::record::FileBinGzRecorder;
#[cfg(not(feature = "std"))]
type Recorder = crate::record::InMemoryBinRecorder;
}
#[cfg(feature = "std")]
impl RecordSettings for NoStdTrainingRecordSettings {
type FloatElem = f32;
type IntElem = i32;
type Recorder = crate::record::FileBinRecorder;
}
impl RecordSettings for NoStdInferenceRecordSettings {
type FloatElem = f32;
type IntElem = i32;
type Recorder = crate::record::InMemoryBinRecorder;
}

View File

@ -0,0 +1,15 @@
use super::{Record, RecordSettings};
use crate::module::State;
use burn_tensor::Element;
impl<T: Element> Record for State<T> {
type Item<S: RecordSettings> = State<S::FloatElem>;
fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
self.convert::<S::FloatElem>()
}
fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
item.convert()
}
}

View File

@ -1,20 +1,19 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::module::State;
use burn_core::tensor::Element;
use burn_core::record::Record;
use std::sync::{mpsc, Arc};
enum Message<E> {
Save(usize, State<E>),
enum Message<R> {
Save(usize, R),
End,
}
#[derive(new)]
struct CheckpointerThread<T> {
checkpointer: Arc<dyn Checkpointer<T> + Send + Sync>,
receiver: mpsc::Receiver<Message<T>>,
struct CheckpointerThread<R> {
checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>,
receiver: mpsc::Receiver<Message<R>>,
}
impl<T> CheckpointerThread<T> {
impl<R: Record> CheckpointerThread<R> {
fn run(self) {
for item in self.receiver.iter() {
match item {
@ -33,8 +32,8 @@ pub struct AsyncCheckpointer<E> {
handler: Option<std::thread::JoinHandle<()>>,
}
impl<E: Element + 'static> AsyncCheckpointer<E> {
pub fn new(checkpointer: Arc<dyn Checkpointer<E> + Send + Sync>) -> Self {
impl<R: Record + 'static> AsyncCheckpointer<R> {
pub fn new(checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>) -> Self {
// Only on checkpoint can be done in advance.
let (sender, receiver) = mpsc::sync_channel(0);
let thread = CheckpointerThread::new(checkpointer.clone(), receiver);
@ -48,17 +47,17 @@ impl<E: Element + 'static> AsyncCheckpointer<E> {
}
}
impl<E> Checkpointer<E> for AsyncCheckpointer<E>
impl<R> Checkpointer<R> for AsyncCheckpointer<R>
where
E: Element + Sync + 'static,
R: Record + 'static,
{
fn save(&self, epoch: usize, state: State<E>) -> Result<(), CheckpointerError> {
self.sender.send(Message::Save(epoch, state)).unwrap();
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
self.sender.send(Message::Save(epoch, record)).unwrap();
Ok(())
}
fn restore(&self, epoch: usize) -> Result<State<E>, CheckpointerError> {
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
self.checkpointer.restore(epoch)
}
}

View File

@ -1,12 +1,17 @@
use burn_core::module::{State, StateError};
use burn_core::{
module::StateError,
record::{Record, RecorderError},
};
#[derive(Debug)]
pub enum CheckpointerError {
IOError(std::io::Error),
RecorderError(RecorderError),
StateError(StateError),
Unknown(String),
}
pub trait Checkpointer<E> {
fn save(&self, epoch: usize, state: State<E>) -> Result<(), CheckpointerError>;
fn restore(&self, epoch: usize) -> Result<State<E>, CheckpointerError>;
pub trait Checkpointer<R: Record> {
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError>;
}

View File

@ -1,25 +1,33 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::module::{State, StateFormat};
use burn_core::tensor::Element;
use std::marker::PhantomData;
pub struct FileCheckpointer<P> {
use super::{Checkpointer, CheckpointerError};
use burn_core::record::{FileRecorder, Record, RecordSettings};
use serde::{de::DeserializeOwned, Serialize};
pub struct FileCheckpointer<S>
where
S: RecordSettings,
S::Recorder: FileRecorder,
{
directory: String,
name: String,
num_keep: usize,
format: StateFormat,
_precision: P,
settings: PhantomData<S>,
}
impl<P: Element> FileCheckpointer<P> {
pub fn new(directory: &str, name: &str, num_keep: usize, format: StateFormat) -> Self {
impl<S> FileCheckpointer<S>
where
S: RecordSettings,
S::Recorder: FileRecorder,
{
pub fn new(directory: &str, name: &str, num_keep: usize) -> Self {
std::fs::create_dir_all(directory).ok();
Self {
directory: directory.to_string(),
name: name.to_string(),
num_keep,
format,
_precision: P::default(),
settings: PhantomData::default(),
}
}
fn path_for_epoch(&self, epoch: usize) -> String {
@ -27,19 +35,20 @@ impl<P: Element> FileCheckpointer<P> {
}
}
impl<E, P> Checkpointer<E> for FileCheckpointer<P>
impl<R, S> Checkpointer<R> for FileCheckpointer<S>
where
P: serde::Serialize + serde::de::DeserializeOwned + Element,
E: Element,
R: Record,
S: RecordSettings,
S::Recorder: FileRecorder,
R::Item<S>: Serialize + DeserializeOwned,
{
fn save(&self, epoch: usize, state: State<E>) -> Result<(), CheckpointerError> {
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Saving checkpoint {} to {}", epoch, file_path);
state
.convert::<P>()
.save(&file_path, &self.format)
.map_err(CheckpointerError::IOError)?;
record
.record(file_path.into())
.map_err(CheckpointerError::RecorderError)?;
if self.num_keep > epoch {
return Ok(());
@ -55,13 +64,11 @@ where
Ok(())
}
fn restore(&self, epoch: usize) -> Result<State<E>, CheckpointerError> {
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
let record = R::load(file_path.into()).map_err(CheckpointerError::RecorderError)?;
let state =
State::<P>::load(&file_path, &self.format).map_err(CheckpointerError::StateError)?;
Ok(state.convert())
Ok(record)
}
}

View File

@ -1,6 +1,6 @@
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::module::ADModule;
use burn_core::module::{ADModule, State};
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::{ADBackend, Backend};
@ -24,8 +24,8 @@ where
pub(super) devices: Vec<B::Device>,
}
type CheckpointModel<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
type CheckpointModel<B> = Option<Box<dyn Checkpointer<State<<B as Backend>::FloatElem>>>>;
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<State<<B as Backend>::FloatElem>>>>;
impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
where

View File

@ -6,10 +6,10 @@ use crate::metric::dashboard::cli::CLIDashboardRenderer;
use crate::metric::dashboard::Dashboard;
use crate::metric::{Adaptor, Metric, Numeric};
use crate::AsyncTrainerCallback;
use burn_core::module::{ADModule, StateFormat};
use burn_core::module::{ADModule, State};
use burn_core::optim::Optimizer;
use burn_core::record::{FileRecorder, RecordSettings};
use burn_core::tensor::backend::ADBackend;
use burn_core::tensor::Element;
use std::sync::Arc;
/// Struct to configure and create a [learner](Learner).
@ -20,8 +20,8 @@ where
B: ADBackend,
{
dashboard: Dashboard<T, V>,
checkpointer_model: Option<Arc<dyn Checkpointer<B::FloatElem> + Send + Sync>>,
checkpointer_optimizer: Option<Arc<dyn Checkpointer<B::FloatElem> + Send + Sync>>,
checkpointer_model: Option<Arc<dyn Checkpointer<State<B::FloatElem>> + Send + Sync>>,
checkpointer_optimizer: Option<Arc<dyn Checkpointer<State<B::FloatElem>> + Send + Sync>>,
num_epochs: usize,
checkpoint: Option<usize>,
directory: String,
@ -140,22 +140,20 @@ where
/// The number of checkpoints to be keep should be set to a minimum of two to be safe, since
/// they are saved and deleted asynchronously and a crash during training might make a
/// checkpoint non-usable.
pub fn with_file_checkpointer<P: Element + serde::de::DeserializeOwned + serde::Serialize>(
mut self,
num_keep: usize,
format: StateFormat,
) -> Self {
self.checkpointer_model = Some(Arc::new(FileCheckpointer::<P>::new(
pub fn with_file_checkpointer<S>(mut self, num_keep: usize) -> Self
where
S: RecordSettings + 'static,
S::Recorder: FileRecorder,
{
self.checkpointer_model = Some(Arc::new(FileCheckpointer::<S>::new(
format!("{}/checkpoint", self.directory).as_str(),
"model",
num_keep,
format.clone(),
)));
self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::<P>::new(
self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::<S>::new(
format!("{}/checkpoint", self.directory).as_str(),
"optim",
num_keep,
format,
)));
self
}
@ -172,7 +170,7 @@ where
let create_checkpointer = |checkpointer| match checkpointer {
Some(checkpointer) => {
let checkpointer: Box<dyn Checkpointer<B::FloatElem>> =
let checkpointer: Box<dyn Checkpointer<State<B::FloatElem>>> =
Box::new(AsyncCheckpointer::new(checkpointer));
Some(checkpointer)
}

View File

@ -1,7 +1,7 @@
use crate::model::Model;
use burn::module::Module;
use burn::module::State;
use burn::record::NoStdInferenceRecordSettings;
use burn::record::Record;
use burn_ndarray::NdArrayBackend;
pub type Backend = NdArrayBackend<f32>;
@ -11,7 +11,8 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
/// Builds and loads trained parameters into the model.
pub fn build_and_load_model() -> Model<Backend> {
let model: Model<Backend> = Model::new();
let state: State<f32> = State::from_bin(STATE_ENCODED).expect("Failed to decode state");
let state = Record::load::<NoStdInferenceRecordSettings>(STATE_ENCODED.to_vec())
.expect("Failed to decode state");
model
.load(&state)

View File

@ -3,9 +3,10 @@ use std::sync::Arc;
use crate::data::MNISTBatcher;
use crate::model::Model;
use burn::module::{Module, StateFormat};
use burn::module::Module;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::{Adam, AdamConfig};
use burn::record::{DefaultRecordSettings, NoStdTrainingRecordSettings, Record};
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
@ -65,7 +66,7 @@ pub fn run<B: ADBackend>(device: B::Device) {
.metric_valid_plot(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
.with_file_checkpointer::<DefaultRecordSettings>(2)
.devices(vec![device])
.num_epochs(config.num_epochs)
.build(model, optim);
@ -79,7 +80,6 @@ pub fn run<B: ADBackend>(device: B::Device) {
// We save a bin version of the model to be loaded with no_std environement.
model_trained
.state()
.convert::<f32>()
.save(&format!("{ARTIFACT_DIR}/model"), &StateFormat::Bin)
.record::<NoStdTrainingRecordSettings>(format!("{ARTIFACT_DIR}/model").into())
.expect("Failed to save trained model");
}

View File

@ -3,8 +3,9 @@ use std::sync::Arc;
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
module::{Module, State, StateFormat},
tensor::backend::Backend,
module::{Module, State},
record::{DefaultRecordSettings, Record},
tensor::{backend::Backend, f16},
};
use crate::{
@ -38,10 +39,8 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
.init::<B>();
println!("Loading weights ...");
let state = State::<burn::tensor::f16>::load(
format!("{artifact_dir}/model").as_str(),
&StateFormat::default(),
)
let state: State<f16> =
Record::load::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
.expect("Trained model weights");
let model = model.load(&state.convert()).expect("Can load weights");
let model = model.to_device(&device);

View File

@ -5,9 +5,10 @@ use crate::{
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
module::{Module, StateFormat},
module::Module,
nn::transformer::TransformerEncoderConfig,
optim::{Sgd, SgdConfig},
record::{DefaultRecordSettings, Record},
tensor::backend::ADBackend,
train::{
metric::{AccuracyMetric, CUDAMetric, LossMetric},
@ -78,7 +79,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
.metric_valid(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
.with_file_checkpointer::<DefaultRecordSettings>(2)
.devices(vec![device])
.num_epochs(config.num_epochs)
.build(model, optim);
@ -89,7 +90,6 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
model_trained
.state()
.convert::<burn::tensor::f16>()
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
.record::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
.unwrap();
}

View File

@ -14,7 +14,10 @@ use burn::{
LearnerBuilder,
},
};
use burn::{data::dataset::transform::SamplerDataset, module::StateFormat};
use burn::{
data::dataset::transform::SamplerDataset,
record::{DefaultRecordSettings, Record},
};
use std::sync::Arc;
#[derive(Config)]
@ -76,7 +79,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
.metric_valid(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
.with_file_checkpointer::<DefaultRecordSettings>(2)
.devices(vec![device])
.grads_accumulation(16)
.num_epochs(config.num_epochs)
@ -88,7 +91,6 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
model_trained
.state()
.convert::<burn::tensor::f16>()
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
.record::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
.unwrap();
}