mirror of https://github.com/tracel-ai/burn.git
Feat/record (#262)
This commit is contained in:
parent
4e9e6d2706
commit
73f6d1916b
|
@ -13,6 +13,7 @@ pub mod optim;
|
|||
|
||||
pub mod module;
|
||||
pub mod nn;
|
||||
pub mod record;
|
||||
pub mod tensor;
|
||||
|
||||
extern crate alloc;
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use alloc::{
|
||||
format,
|
||||
string::{String, ToString},
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
use super::ParamId;
|
||||
|
@ -28,24 +27,7 @@ pub enum State<E> {
|
|||
#[derive(Debug)]
|
||||
pub enum StateError {
|
||||
InvalidFormat(String),
|
||||
FileNotFound(String),
|
||||
}
|
||||
|
||||
/// All supported format.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The default file format is compressed bincode for the smallest file size possible.
|
||||
/// For `no_std` environments, you should use (StateFormat::Bin)[StateFormat::Bin] since compression isn't supported.
|
||||
/// However, the bincode format alone is smaller than compressed `json` or `msgpack`.
|
||||
#[derive(Default, Clone)]
|
||||
pub enum StateFormat {
|
||||
#[default]
|
||||
BinGz,
|
||||
Bin,
|
||||
JsonGz,
|
||||
#[cfg(feature = "msgpack")]
|
||||
MpkGz,
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl core::fmt::Display for StateError {
|
||||
|
@ -56,8 +38,8 @@ impl core::fmt::Display for StateError {
|
|||
Self::InvalidFormat(err) => {
|
||||
message += format!("Invalid format: {err}").as_str();
|
||||
}
|
||||
Self::FileNotFound(err) => {
|
||||
message += format!("File not found: {err}").as_str();
|
||||
Self::Unknown(err) => {
|
||||
message += format!("Unknown error: {err}").as_str();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -122,170 +104,6 @@ impl<E: Element> State<E> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E: Element> State<E>
|
||||
where
|
||||
E: serde::de::DeserializeOwned,
|
||||
E: serde::Serialize,
|
||||
{
|
||||
pub fn to_bin(&self) -> Result<Vec<u8>, StateError> {
|
||||
Ok(bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap())
|
||||
}
|
||||
|
||||
pub fn from_bin(data: &[u8]) -> Result<Self, StateError> {
|
||||
let state = bincode::serde::decode_borrowed_from_slice(data, Self::bin_config()).unwrap();
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
fn bin_config() -> bincode::config::Configuration {
|
||||
bincode::config::standard()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod std_enabled {
|
||||
use super::*;
|
||||
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||
use std::{fs::File, path::Path};
|
||||
|
||||
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
|
||||
impl std::error::Error for StateError {}
|
||||
|
||||
macro_rules! str2reader {
|
||||
(
|
||||
$file:expr,
|
||||
$ext:expr
|
||||
) => {{
|
||||
let path_ref = &format!("{}.{}", $file, $ext);
|
||||
let path = Path::new(path_ref);
|
||||
|
||||
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! str2writer {
|
||||
(
|
||||
$file:expr,
|
||||
$ext:expr
|
||||
) => {{
|
||||
let path_ref = &format!("{}.{}", $file, $ext);
|
||||
let path = Path::new(path_ref);
|
||||
if path.exists() {
|
||||
log::info!("File exists, replacing");
|
||||
std::fs::remove_file(path).unwrap();
|
||||
}
|
||||
|
||||
File::create(path)
|
||||
}};
|
||||
}
|
||||
impl<E: Element> State<E>
|
||||
where
|
||||
E: serde::de::DeserializeOwned,
|
||||
E: serde::Serialize,
|
||||
{
|
||||
/// Save the state to the provided file path using the given [StateFormat](StateFormat).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The file extension will be added automatically depending on the state format.
|
||||
pub fn save(self, file: &str, format: &StateFormat) -> std::io::Result<()> {
|
||||
match format {
|
||||
StateFormat::BinGz => self.save_bingz(file),
|
||||
StateFormat::Bin => self.save_bin(file),
|
||||
StateFormat::JsonGz => self.save_jsongz(file),
|
||||
#[cfg(feature = "msgpack")]
|
||||
StateFormat::MpkGz => self.save_mpkgz(file),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the state from the provided file path using the given [StateFormat](StateFormat).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The file extension will be added automatically depending on the state format.
|
||||
pub fn load(file: &str, format: &StateFormat) -> Result<Self, StateError> {
|
||||
match format {
|
||||
StateFormat::BinGz => Self::load_bingz(file),
|
||||
StateFormat::Bin => Self::load_bin(file),
|
||||
StateFormat::JsonGz => Self::load_jsongz(file),
|
||||
#[cfg(feature = "msgpack")]
|
||||
StateFormat::MpkGz => Self::load_mpkgz(file),
|
||||
}
|
||||
}
|
||||
|
||||
fn save_jsongz(self, file: &str) -> std::io::Result<()> {
|
||||
let writer = str2writer!(file, "json.gz")?;
|
||||
let writer = GzEncoder::new(writer, Compression::default());
|
||||
serde_json::to_writer(writer, &self).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_jsongz(file: &str) -> Result<Self, StateError> {
|
||||
let reader = str2reader!(file, "json.gz")?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = serde_json::from_reader(reader).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
fn save_mpkgz(self, file: &str) -> std::io::Result<()> {
|
||||
let writer = str2writer!(file, "mpk.gz")?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
rmp_serde::encode::write(&mut writer, &self).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
fn load_mpkgz(file: &str) -> Result<Self, StateError> {
|
||||
let reader = str2reader!(file, "mpk.gz")?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = rmp_serde::decode::from_read(reader).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
fn save_bingz(self, file: &str) -> std::io::Result<()> {
|
||||
let config = Self::bin_config();
|
||||
let writer = str2writer!(file, "bin.gz")?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
|
||||
bincode::serde::encode_into_std_write(&self, &mut writer, config).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_bingz(file: &str) -> Result<Self, StateError> {
|
||||
let reader = str2reader!(file, "bin.gz")?;
|
||||
let mut reader = GzDecoder::new(reader);
|
||||
let state =
|
||||
bincode::serde::decode_from_std_read(&mut reader, Self::bin_config()).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
fn save_bin(self, file: &str) -> std::io::Result<()> {
|
||||
let buf = bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap();
|
||||
|
||||
let mut writer = str2writer!(file, "bin")?;
|
||||
std::io::Write::write_all(&mut writer, &buf).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_bin(file: &str) -> Result<Self, StateError> {
|
||||
let mut reader = str2reader!(file, "bin")?;
|
||||
let mut buf = Vec::new();
|
||||
std::io::Read::read_to_end(&mut reader, &mut buf).unwrap();
|
||||
let state =
|
||||
bincode::serde::decode_borrowed_from_slice(&buf, Self::bin_config()).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -313,6 +131,7 @@ mod tests {
|
|||
let params_before_2 = list_param_ids(&model_2);
|
||||
|
||||
let state = model_1.state();
|
||||
|
||||
model_2 = model_2.load(&state).unwrap();
|
||||
let params_after_2 = list_param_ids(&model_2);
|
||||
|
||||
|
@ -320,80 +139,7 @@ mod tests {
|
|||
assert_eq!(params_before_1, params_after_2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_to_binary() {
|
||||
let model_1 = create_model();
|
||||
let model_2 = create_model();
|
||||
let params_before_1 = list_param_ids(&model_1);
|
||||
let params_before_2 = list_param_ids(&model_2);
|
||||
|
||||
// To & From Bytes
|
||||
let bytes = model_1.state().to_bin().unwrap();
|
||||
let model_2 = model_2.load(&State::from_bin(&bytes).unwrap()).unwrap();
|
||||
|
||||
// Verify.
|
||||
let params_after_2 = list_param_ids(&model_2);
|
||||
assert_ne!(params_before_1, params_before_2);
|
||||
assert_eq!(params_before_1, params_after_2);
|
||||
}
|
||||
|
||||
pub fn create_model() -> nn::Linear<TestBackend> {
|
||||
nn::LinearConfig::new(32, 32).with_bias(true).init()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests_save_load {
|
||||
use super::tests::create_model;
|
||||
use super::*;
|
||||
use crate::module::Module;
|
||||
|
||||
static FILE_PATH: &str = "/tmp/test_state";
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_jsongz_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::JsonGz)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_bin_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::Bin)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_bingz_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::BinGz)
|
||||
}
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_mpkgz_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::MpkGz)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_bin_on_disk() {
|
||||
let model = create_model();
|
||||
model
|
||||
.state()
|
||||
.save("/tmp/model_compare", &StateFormat::Bin)
|
||||
.unwrap();
|
||||
let bytes = std::fs::read("/tmp/model_compare.bin").unwrap();
|
||||
let state = State::from_bin(&bytes).unwrap();
|
||||
|
||||
assert_eq!(state, model.state());
|
||||
}
|
||||
|
||||
fn test_can_save_and_load_from_file(format: StateFormat) {
|
||||
let model_before = create_model();
|
||||
let state_before = model_before.state();
|
||||
state_before.clone().save(FILE_PATH, &format).unwrap();
|
||||
|
||||
let model_after = create_model()
|
||||
.load(&State::load(FILE_PATH, &format).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let state_after = model_after.state();
|
||||
assert_eq!(state_before, state_after);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
use super::{RecordSettings, Recorder, RecorderError};
|
||||
use crate::alloc::string::ToString;
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
||||
/// Trait to define a family of types which can be recorded using any [settings](RecordSettings).
|
||||
pub trait Record: Send + Sync {
|
||||
type Item<S: RecordSettings>;
|
||||
|
||||
/// Convert the current record into the corresponding item that follows the given [settings](RecordSettings).
|
||||
fn into_item<S: RecordSettings>(self) -> Self::Item<S>;
|
||||
/// Convert the given item into a record.
|
||||
fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self;
|
||||
|
||||
/// Record using the given [settings](RecordSettings).
|
||||
fn record<S>(self, args: RecordArgs<S>) -> RecordOutputResult<S>
|
||||
where
|
||||
Self: Sized,
|
||||
S: RecordSettings,
|
||||
Self::Item<S>: Serialize + DeserializeOwned,
|
||||
{
|
||||
let metadata = BurnMetadata::new(
|
||||
core::any::type_name::<S::FloatElem>().to_string(),
|
||||
core::any::type_name::<S::IntElem>().to_string(),
|
||||
core::any::type_name::<S::Recorder>().to_string(),
|
||||
env!("CARGO_PKG_VERSION").to_string(),
|
||||
format!("{:?}", S::default()),
|
||||
);
|
||||
let item = self.into_item::<S>();
|
||||
let record = BurnRecord::new(item, metadata);
|
||||
|
||||
RecorderType::<S>::record(record, args)
|
||||
}
|
||||
|
||||
/// Load the record using the given [settings](RecordSettings).
|
||||
fn load<S>(args: LoadArgs<S>) -> Result<Self, RecorderError>
|
||||
where
|
||||
Self: Sized,
|
||||
S: RecordSettings,
|
||||
Self::Item<S>: Serialize + DeserializeOwned,
|
||||
{
|
||||
let record: BurnRecord<Self::Item<S>> =
|
||||
RecorderType::<S>::load(args.clone()).map_err(|err| {
|
||||
let message = match err {
|
||||
RecorderError::FileNotFound(_) => return err,
|
||||
RecorderError::Unknown(message) => message,
|
||||
};
|
||||
let record = RecorderType::<S>::load::<BurnRecordNoItem>(args);
|
||||
|
||||
let message = match record {
|
||||
Ok(record) => format!(
|
||||
"Unable to load record with settings {:?}, found metadata {:?}, err: {:?}",
|
||||
S::default(),
|
||||
record.metadata,
|
||||
message
|
||||
),
|
||||
Err(_err) => message,
|
||||
};
|
||||
RecorderError::Unknown(message)
|
||||
})?;
|
||||
|
||||
Ok(Self::from_item(record.item))
|
||||
}
|
||||
}
|
||||
|
||||
/// Record arguments for the given settings.
|
||||
pub type RecordArgs<S> = <<S as RecordSettings>::Recorder as Recorder>::RecordArgs;
|
||||
/// Record loading arguments for the given settings.
|
||||
pub type LoadArgs<S> = <<S as RecordSettings>::Recorder as Recorder>::LoadArgs;
|
||||
/// Record output result for the given settings.
|
||||
pub type RecordOutputResult<S> =
|
||||
Result<<<S as RecordSettings>::Recorder as Recorder>::RecordOutput, RecorderError>;
|
||||
/// Recorder for the given settings.
|
||||
pub type RecorderType<S> = <S as RecordSettings>::Recorder;
|
||||
|
||||
#[derive(new, Debug, Serialize, Deserialize)]
|
||||
struct BurnMetadata {
|
||||
float: String,
|
||||
int: String,
|
||||
format: String,
|
||||
version: String,
|
||||
settings: String,
|
||||
}
|
||||
|
||||
#[derive(new, Serialize, Deserialize)]
|
||||
struct BurnRecord<I> {
|
||||
item: I,
|
||||
metadata: BurnMetadata,
|
||||
}
|
||||
|
||||
#[derive(new, Serialize, Deserialize)]
|
||||
struct BurnRecordNoItem {
|
||||
metadata: BurnMetadata,
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use super::*;
    use crate::record::FileJsonGzRecorder;
    use burn_tensor::{Element, ElementConversion};

    /// Base path (the recorder sets the extension) of the record written by the test.
    ///
    /// Resolved through the platform temporary directory instead of a hard-coded `/tmp`.
    fn file_path() -> std::path::PathBuf {
        std::env::temp_dir().join("burn_test_record")
    }

    /// Loading a record with settings whose element type cannot hold the stored values must
    /// report an error.
    #[test]
    fn err_when_invalid_object() {
        #[derive(Debug, Default)]
        pub struct TestSettings<F> {
            float: PhantomData<F>,
        }

        impl<F: Element + Serialize + DeserializeOwned> RecordSettings for TestSettings<F> {
            type FloatElem = F;
            type IntElem = i32;
            type Recorder = FileJsonGzRecorder;
        }

        #[derive(new, Serialize, Deserialize)]
        struct Item<S: RecordSettings> {
            value: S::FloatElem,
        }

        impl<D: RecordSettings> Record for Item<D> {
            type Item<S: RecordSettings> = Item<S>;

            fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
                Item {
                    value: self.value.elem(),
                }
            }

            fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
                Item {
                    value: item.value.elem(),
                }
            }
        }

        let item = Item::<TestSettings<f32>>::new(16.elem());

        // Serialize in f32. This step has to succeed, otherwise the check below proves nothing.
        item.record::<TestSettings<f32>>(file_path()).unwrap();

        // The stored value is an f32; it can't be deserialized as a u8.
        let loaded = Item::<TestSettings<f32>>::load::<TestSettings<u8>>(file_path());
        assert!(loaded.is_err());
    }
}
|
|
@ -0,0 +1,217 @@
|
|||
use super::{bin_config, Recorder, RecorderError};
|
||||
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
/// Recorder trait specialized to save and load data to and from files.
|
||||
pub trait FileRecorder:
|
||||
Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
|
||||
{
|
||||
}
|
||||
|
||||
/// File recorder using the [bincode format](bincode).
|
||||
pub struct FileBinRecorder;
|
||||
/// File recorder using the [bincode format](bincode) compressed with gzip.
|
||||
pub struct FileBinGzRecorder;
|
||||
/// File recorder using the json format compressed with gzip.
|
||||
pub struct FileJsonGzRecorder;
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
/// File recorder using the [message pack](rmp_serde) format compressed with gzip.
|
||||
pub struct FileMpkGzRecorder;
|
||||
|
||||
impl FileRecorder for FileBinGzRecorder {}
|
||||
impl FileRecorder for FileBinRecorder {}
|
||||
impl FileRecorder for FileJsonGzRecorder {}
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
impl FileRecorder for FileMpkGzRecorder {}
|
||||
|
||||
/// Sets `$ext` as the extension of the mutable path `$file` and opens that file for reading.
///
/// Evaluates to a `Result<File, RecorderError>`: a missing file is reported as
/// `RecorderError::FileNotFound`, any other I/O failure as `RecorderError::Unknown`.
macro_rules! str2reader {
    ($file:expr, $ext:expr) => {{
        $file.set_extension($ext);

        match File::open($file.as_path()) {
            Ok(handle) => Ok(handle),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
|
||||
|
||||
/// Sets `$ext` as the extension of the mutable path `$file` and creates that file for writing.
///
/// A file already present at the path is removed first; a failure to remove it makes the
/// enclosing function return `RecorderError::Unknown` through `?`. The macro itself evaluates
/// to a `Result<File, RecorderError>` describing the creation.
macro_rules! str2writer {
    ($file:expr, $ext:expr) => {{
        $file.set_extension($ext);
        let target = $file.as_path();

        let already_present = target.exists();
        if already_present {
            log::info!("File exists, replacing");
            std::fs::remove_file(target)
                .map_err(|err| RecorderError::Unknown(err.to_string()))?;
        }

        match File::create(target) {
            Ok(handle) => Ok(handle),
            Err(err) => Err(match err.kind() {
                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
                _ => RecorderError::Unknown(err.to_string()),
            }),
        }
    }};
}
|
||||
|
||||
impl Recorder for FileBinGzRecorder {
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn record<Obj: Serialize + DeserializeOwned>(
|
||||
obj: Obj,
|
||||
mut file: PathBuf,
|
||||
) -> Result<(), RecorderError> {
|
||||
let config = bin_config();
|
||||
let writer = str2writer!(file, "bin.gz")?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
|
||||
bincode::serde::encode_into_std_write(&obj, &mut writer, config)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
|
||||
let reader = str2reader!(file, "bin.gz")?;
|
||||
let mut reader = GzDecoder::new(reader);
|
||||
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Recorder for FileBinRecorder {
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn record<Obj: Serialize + DeserializeOwned>(
|
||||
obj: Obj,
|
||||
mut file: PathBuf,
|
||||
) -> Result<(), RecorderError> {
|
||||
let config = bin_config();
|
||||
let mut writer = str2writer!(file, "bin")?;
|
||||
bincode::serde::encode_into_std_write(&obj, &mut writer, config)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
|
||||
let mut reader = str2reader!(file, "bin")?;
|
||||
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Recorder for FileJsonGzRecorder {
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn record<Obj: Serialize + DeserializeOwned>(
|
||||
obj: Obj,
|
||||
mut file: PathBuf,
|
||||
) -> Result<(), RecorderError> {
|
||||
let writer = str2writer!(file, "json.gz")?;
|
||||
let writer = GzEncoder::new(writer, Compression::default());
|
||||
serde_json::to_writer(writer, &obj)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
|
||||
let reader = str2reader!(file, "json.gz")?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = serde_json::from_reader(reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "msgpack")]
impl Recorder for FileMpkGzRecorder {
    // These names must match the associated items declared by `Recorder`
    // (`RecordArgs`, `RecordOutput`, `record`), like the other file recorders.
    type RecordArgs = PathBuf;
    type RecordOutput = ();
    type LoadArgs = PathBuf;

    /// Encode `obj` as message pack, gzip it and write it to `file` with the `mpk.gz` extension.
    ///
    /// # Errors
    ///
    /// Returns a [RecorderError] when the file cannot be created, when encoding fails, or when
    /// the gzip stream cannot be completed.
    fn record<Obj: Serialize + DeserializeOwned>(
        obj: Obj,
        mut file: PathBuf,
    ) -> Result<(), RecorderError> {
        let writer = str2writer!(file, "mpk.gz")?;
        let mut writer = GzEncoder::new(writer, Compression::default());

        rmp_serde::encode::write(&mut writer, &obj)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;

        // Finish explicitly: dropping the encoder would also try to complete the stream, but
        // any I/O error raised while doing so would be silently discarded.
        writer
            .finish()
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;

        Ok(())
    }

    /// Read `file` with the `mpk.gz` extension, gunzip it and decode the message pack content.
    ///
    /// # Errors
    ///
    /// Returns a [RecorderError] when the file cannot be opened or its content cannot be decoded.
    fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
        let reader = str2reader!(file, "mpk.gz")?;
        let reader = GzDecoder::new(reader);
        let state = rmp_serde::decode::from_read(reader)
            .map_err(|err| RecorderError::Unknown(err.to_string()))?;

        Ok(state)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Module, nn, TestBackend};

    /// Base path (each recorder sets its own extension) of the files written by the tests.
    ///
    /// Resolved through the platform temporary directory instead of a hard-coded `/tmp`.
    fn file_path() -> std::path::PathBuf {
        std::env::temp_dir().join("burn_test_file_recorder")
    }

    #[test]
    fn test_can_save_and_load_jsongz_format() {
        test_can_save_and_load::<FileJsonGzRecorder>()
    }

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load::<FileBinRecorder>()
    }

    #[test]
    fn test_can_save_and_load_bingz_format() {
        test_can_save_and_load::<FileBinGzRecorder>()
    }

    #[cfg(feature = "msgpack")]
    #[test]
    fn test_can_save_and_load_mpkgz_format() {
        test_can_save_and_load::<FileMpkGzRecorder>()
    }

    /// Records the state of a fresh model with `R`, loads it into another fresh model and
    /// checks that both states are equal.
    // The type parameter is named `R` so it does not shadow the `Recorder` trait.
    fn test_can_save_and_load<R: FileRecorder>() {
        let model_before = create_model();
        let state_before = model_before.state();
        R::record(state_before.clone(), file_path()).unwrap();

        let model_after = create_model()
            .load(&R::load(file_path()).unwrap())
            .unwrap();

        let state_after = model_after.state();
        assert_eq!(state_before, state_after);
    }

    pub fn create_model() -> nn::Linear<TestBackend> {
        nn::LinearConfig::new(32, 32).with_bias(true).init()
    }
}
|
|
@ -0,0 +1,65 @@
|
|||
use super::{bin_config, Recorder, RecorderError};
|
||||
use alloc::vec::Vec;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
/// Recorder trait specialized to save and load data to and from bytes.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This is especialy useful in no_std environment where weights are stored directly in
|
||||
/// compiled binaries.
|
||||
pub trait InMemoryRecorder:
|
||||
Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
|
||||
{
|
||||
}
|
||||
|
||||
/// In memory recorder using the [bincode format](bincode).
|
||||
pub struct InMemoryBinRecorder;
|
||||
|
||||
impl InMemoryRecorder for InMemoryBinRecorder {}
|
||||
|
||||
impl Recorder for InMemoryBinRecorder {
|
||||
type RecordArgs = ();
|
||||
type RecordOutput = Vec<u8>;
|
||||
type LoadArgs = Vec<u8>;
|
||||
|
||||
fn record<Obj: Serialize + DeserializeOwned>(
|
||||
obj: Obj,
|
||||
_args: Self::RecordArgs,
|
||||
) -> Result<Vec<u8>, RecorderError> {
|
||||
Ok(bincode::serde::encode_to_vec(obj, bin_config()).unwrap())
|
||||
}
|
||||
|
||||
fn load<Obj: Serialize + DeserializeOwned>(args: Self::LoadArgs) -> Result<Obj, RecorderError> {
|
||||
let state = bincode::serde::decode_borrowed_from_slice(&args, bin_config()).unwrap();
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Module, nn, TestBackend};

    #[test]
    fn test_can_save_and_load_bin_format() {
        test_can_save_and_load::<InMemoryBinRecorder>()
    }

    /// Records the state of a fresh model into bytes with `R`, loads those bytes into another
    /// fresh model and checks that both states are equal.
    fn test_can_save_and_load<R: InMemoryRecorder>() {
        let state_before = create_model().state();
        let bytes = R::record(state_before.clone(), ()).unwrap();

        let restored = R::load(bytes).unwrap();
        let model_after = create_model().load(&restored).unwrap();

        assert_eq!(state_before, model_after.state());
    }

    pub fn create_model() -> nn::Linear<TestBackend> {
        nn::LinearConfig::new(32, 32).with_bias(true).init()
    }
}
|
|
@ -0,0 +1,16 @@
|
|||
mod state;
|
||||
|
||||
mod base;
|
||||
mod memory;
|
||||
mod recorder;
|
||||
mod settings;
|
||||
|
||||
pub use base::*;
|
||||
pub use memory::*;
|
||||
pub use recorder::*;
|
||||
pub use settings::*;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod file;
|
||||
#[cfg(feature = "std")]
|
||||
pub use file::*;
|
|
@ -0,0 +1,42 @@
|
|||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
|
||||
pub trait Recorder: Send + Sync {
|
||||
/// Arguments used to record objects.
|
||||
type RecordArgs: Clone;
|
||||
/// Record output type.
|
||||
type RecordOutput;
|
||||
/// Arguments used to load recorded objects.
|
||||
type LoadArgs: Clone;
|
||||
|
||||
fn record<Item: Serialize + DeserializeOwned>(
|
||||
item: Item,
|
||||
args: Self::RecordArgs,
|
||||
) -> Result<Self::RecordOutput, RecorderError>;
|
||||
/// Load an item from the given arguments.
|
||||
fn load<Item: Serialize + DeserializeOwned>(
|
||||
args: Self::LoadArgs,
|
||||
) -> Result<Item, RecorderError>;
|
||||
}
|
||||
|
||||
/// Errors reported by a recorder.
#[derive(Debug)]
pub enum RecorderError {
    /// The file to read or write could not be found; carries the underlying message.
    FileNotFound(String),
    /// Any other failure; carries the underlying message.
    Unknown(String),
}

impl core::fmt::Display for RecorderError {
    /// Writes the same text as the `Debug` representation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("{self:?}");

        f.write_str(&rendered)
    }
}

// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
#[cfg(feature = "std")]
impl std::error::Error for RecorderError {}
|
||||
|
||||
/// Bincode configuration shared by every bincode-based recorder of the crate.
pub(crate) fn bin_config() -> bincode::config::Configuration {
    let config = bincode::config::standard();

    config
}
|
|
@ -0,0 +1,44 @@
|
|||
use super::Recorder;
|
||||
use burn_tensor::Element;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
pub trait RecordSettings: Send + Sync + core::fmt::Debug + core::default::Default {
|
||||
type FloatElem: Element + Serialize + DeserializeOwned;
|
||||
type IntElem: Element + Serialize + DeserializeOwned;
|
||||
type Recorder: Recorder;
|
||||
}
|
||||
|
||||
/// Default record settings.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct DefaultRecordSettings;
|
||||
/// Training settings compatible with no-std inference.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct NoStdTrainingRecordSettings;
|
||||
/// Inference settings compatible with no-std.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct NoStdInferenceRecordSettings;
|
||||
|
||||
impl RecordSettings for DefaultRecordSettings {
|
||||
#[cfg(feature = "std")]
|
||||
type FloatElem = half::f16;
|
||||
#[cfg(not(feature = "std"))]
|
||||
type FloatElem = f32;
|
||||
type IntElem = i16;
|
||||
#[cfg(feature = "std")]
|
||||
type Recorder = crate::record::FileBinGzRecorder;
|
||||
#[cfg(not(feature = "std"))]
|
||||
type Recorder = crate::record::InMemoryBinRecorder;
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl RecordSettings for NoStdTrainingRecordSettings {
|
||||
type FloatElem = f32;
|
||||
type IntElem = i32;
|
||||
type Recorder = crate::record::FileBinRecorder;
|
||||
}
|
||||
|
||||
impl RecordSettings for NoStdInferenceRecordSettings {
|
||||
type FloatElem = f32;
|
||||
type IntElem = i32;
|
||||
type Recorder = crate::record::InMemoryBinRecorder;
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
use super::{Record, RecordSettings};
|
||||
use crate::module::State;
|
||||
use burn_tensor::Element;
|
||||
|
||||
impl<T: Element> Record for State<T> {
|
||||
type Item<S: RecordSettings> = State<S::FloatElem>;
|
||||
|
||||
fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
|
||||
self.convert::<S::FloatElem>()
|
||||
}
|
||||
|
||||
fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
|
||||
item.convert()
|
||||
}
|
||||
}
|
|
@ -1,20 +1,19 @@
|
|||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::module::State;
|
||||
use burn_core::tensor::Element;
|
||||
use burn_core::record::Record;
|
||||
use std::sync::{mpsc, Arc};
|
||||
|
||||
enum Message<E> {
|
||||
Save(usize, State<E>),
|
||||
enum Message<R> {
|
||||
Save(usize, R),
|
||||
End,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct CheckpointerThread<T> {
|
||||
checkpointer: Arc<dyn Checkpointer<T> + Send + Sync>,
|
||||
receiver: mpsc::Receiver<Message<T>>,
|
||||
struct CheckpointerThread<R> {
|
||||
checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>,
|
||||
receiver: mpsc::Receiver<Message<R>>,
|
||||
}
|
||||
|
||||
impl<T> CheckpointerThread<T> {
|
||||
impl<R: Record> CheckpointerThread<R> {
|
||||
fn run(self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
|
@ -33,8 +32,8 @@ pub struct AsyncCheckpointer<E> {
|
|||
handler: Option<std::thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl<E: Element + 'static> AsyncCheckpointer<E> {
|
||||
pub fn new(checkpointer: Arc<dyn Checkpointer<E> + Send + Sync>) -> Self {
|
||||
impl<R: Record + 'static> AsyncCheckpointer<R> {
|
||||
pub fn new(checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>) -> Self {
|
||||
// Only one checkpoint can be done in advance.
|
||||
let (sender, receiver) = mpsc::sync_channel(0);
|
||||
let thread = CheckpointerThread::new(checkpointer.clone(), receiver);
|
||||
|
@ -48,17 +47,17 @@ impl<E: Element + 'static> AsyncCheckpointer<E> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E> Checkpointer<E> for AsyncCheckpointer<E>
|
||||
impl<R> Checkpointer<R> for AsyncCheckpointer<R>
|
||||
where
|
||||
E: Element + Sync + 'static,
|
||||
R: Record + 'static,
|
||||
{
|
||||
fn save(&self, epoch: usize, state: State<E>) -> Result<(), CheckpointerError> {
|
||||
self.sender.send(Message::Save(epoch, state)).unwrap();
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
self.sender.send(Message::Save(epoch, record)).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn restore(&self, epoch: usize) -> Result<State<E>, CheckpointerError> {
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
|
||||
self.checkpointer.restore(epoch)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
use burn_core::module::{State, StateError};
|
||||
use burn_core::{
|
||||
module::StateError,
|
||||
record::{Record, RecorderError},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CheckpointerError {
|
||||
IOError(std::io::Error),
|
||||
RecorderError(RecorderError),
|
||||
StateError(StateError),
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
pub trait Checkpointer<E> {
|
||||
fn save(&self, epoch: usize, state: State<E>) -> Result<(), CheckpointerError>;
|
||||
fn restore(&self, epoch: usize) -> Result<State<E>, CheckpointerError>;
|
||||
pub trait Checkpointer<R: Record> {
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError>;
|
||||
}
|
||||
|
|
|
@ -1,25 +1,33 @@
|
|||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::module::{State, StateFormat};
|
||||
use burn_core::tensor::Element;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub struct FileCheckpointer<P> {
|
||||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::record::{FileRecorder, Record, RecordSettings};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
pub struct FileCheckpointer<S>
|
||||
where
|
||||
S: RecordSettings,
|
||||
S::Recorder: FileRecorder,
|
||||
{
|
||||
directory: String,
|
||||
name: String,
|
||||
num_keep: usize,
|
||||
format: StateFormat,
|
||||
_precision: P,
|
||||
settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<P: Element> FileCheckpointer<P> {
|
||||
pub fn new(directory: &str, name: &str, num_keep: usize, format: StateFormat) -> Self {
|
||||
impl<S> FileCheckpointer<S>
|
||||
where
|
||||
S: RecordSettings,
|
||||
S::Recorder: FileRecorder,
|
||||
{
|
||||
pub fn new(directory: &str, name: &str, num_keep: usize) -> Self {
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
|
||||
Self {
|
||||
directory: directory.to_string(),
|
||||
name: name.to_string(),
|
||||
num_keep,
|
||||
format,
|
||||
_precision: P::default(),
|
||||
settings: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
fn path_for_epoch(&self, epoch: usize) -> String {
|
||||
|
@ -27,19 +35,20 @@ impl<P: Element> FileCheckpointer<P> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E, P> Checkpointer<E> for FileCheckpointer<P>
|
||||
impl<R, S> Checkpointer<R> for FileCheckpointer<S>
|
||||
where
|
||||
P: serde::Serialize + serde::de::DeserializeOwned + Element,
|
||||
E: Element,
|
||||
R: Record,
|
||||
S: RecordSettings,
|
||||
S::Recorder: FileRecorder,
|
||||
R::Item<S>: Serialize + DeserializeOwned,
|
||||
{
|
||||
fn save(&self, epoch: usize, state: State<E>) -> Result<(), CheckpointerError> {
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!("Saving checkpoint {} to {}", epoch, file_path);
|
||||
|
||||
state
|
||||
.convert::<P>()
|
||||
.save(&file_path, &self.format)
|
||||
.map_err(CheckpointerError::IOError)?;
|
||||
record
|
||||
.record(file_path.into())
|
||||
.map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
if self.num_keep > epoch {
|
||||
return Ok(());
|
||||
|
@ -55,13 +64,11 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn restore(&self, epoch: usize) -> Result<State<E>, CheckpointerError> {
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
|
||||
let record = R::load(file_path.into()).map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
let state =
|
||||
State::<P>::load(&file_path, &self.format).map_err(CheckpointerError::StateError)?;
|
||||
|
||||
Ok(state.convert())
|
||||
Ok(record)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::checkpoint::Checkpointer;
|
||||
use crate::LearnerCallback;
|
||||
use burn_core::module::ADModule;
|
||||
use burn_core::module::{ADModule, State};
|
||||
use burn_core::optim::Optimizer;
|
||||
use burn_core::tensor::backend::{ADBackend, Backend};
|
||||
|
||||
|
@ -24,8 +24,8 @@ where
|
|||
pub(super) devices: Vec<B::Device>,
|
||||
}
|
||||
|
||||
type CheckpointModel<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
|
||||
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
|
||||
type CheckpointModel<B> = Option<Box<dyn Checkpointer<State<<B as Backend>::FloatElem>>>>;
|
||||
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<State<<B as Backend>::FloatElem>>>>;
|
||||
|
||||
impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
|
||||
where
|
||||
|
|
|
@ -6,10 +6,10 @@ use crate::metric::dashboard::cli::CLIDashboardRenderer;
|
|||
use crate::metric::dashboard::Dashboard;
|
||||
use crate::metric::{Adaptor, Metric, Numeric};
|
||||
use crate::AsyncTrainerCallback;
|
||||
use burn_core::module::{ADModule, StateFormat};
|
||||
use burn_core::module::{ADModule, State};
|
||||
use burn_core::optim::Optimizer;
|
||||
use burn_core::record::{FileRecorder, RecordSettings};
|
||||
use burn_core::tensor::backend::ADBackend;
|
||||
use burn_core::tensor::Element;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Struct to configure and create a [learner](Learner).
|
||||
|
@ -20,8 +20,8 @@ where
|
|||
B: ADBackend,
|
||||
{
|
||||
dashboard: Dashboard<T, V>,
|
||||
checkpointer_model: Option<Arc<dyn Checkpointer<B::FloatElem> + Send + Sync>>,
|
||||
checkpointer_optimizer: Option<Arc<dyn Checkpointer<B::FloatElem> + Send + Sync>>,
|
||||
checkpointer_model: Option<Arc<dyn Checkpointer<State<B::FloatElem>> + Send + Sync>>,
|
||||
checkpointer_optimizer: Option<Arc<dyn Checkpointer<State<B::FloatElem>> + Send + Sync>>,
|
||||
num_epochs: usize,
|
||||
checkpoint: Option<usize>,
|
||||
directory: String,
|
||||
|
@ -140,22 +140,20 @@ where
|
|||
/// The number of checkpoints to keep should be set to a minimum of two to be safe, since
|
||||
/// they are saved and deleted asynchronously and a crash during training might make a
|
||||
/// checkpoint non-usable.
|
||||
pub fn with_file_checkpointer<P: Element + serde::de::DeserializeOwned + serde::Serialize>(
|
||||
mut self,
|
||||
num_keep: usize,
|
||||
format: StateFormat,
|
||||
) -> Self {
|
||||
self.checkpointer_model = Some(Arc::new(FileCheckpointer::<P>::new(
|
||||
pub fn with_file_checkpointer<S>(mut self, num_keep: usize) -> Self
|
||||
where
|
||||
S: RecordSettings + 'static,
|
||||
S::Recorder: FileRecorder,
|
||||
{
|
||||
self.checkpointer_model = Some(Arc::new(FileCheckpointer::<S>::new(
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"model",
|
||||
num_keep,
|
||||
format.clone(),
|
||||
)));
|
||||
self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::<P>::new(
|
||||
self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::<S>::new(
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"optim",
|
||||
num_keep,
|
||||
format,
|
||||
)));
|
||||
self
|
||||
}
|
||||
|
@ -172,7 +170,7 @@ where
|
|||
|
||||
let create_checkpointer = |checkpointer| match checkpointer {
|
||||
Some(checkpointer) => {
|
||||
let checkpointer: Box<dyn Checkpointer<B::FloatElem>> =
|
||||
let checkpointer: Box<dyn Checkpointer<State<B::FloatElem>>> =
|
||||
Box::new(AsyncCheckpointer::new(checkpointer));
|
||||
Some(checkpointer)
|
||||
}
|
||||
|
|
Binary file not shown.
|
@ -1,7 +1,7 @@
|
|||
use crate::model::Model;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::module::State;
|
||||
use burn::record::NoStdInferenceRecordSettings;
|
||||
use burn::record::Record;
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
|
||||
pub type Backend = NdArrayBackend<f32>;
|
||||
|
@ -11,7 +11,8 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
|
|||
/// Builds and loads trained parameters into the model.
|
||||
pub fn build_and_load_model() -> Model<Backend> {
|
||||
let model: Model<Backend> = Model::new();
|
||||
let state: State<f32> = State::from_bin(STATE_ENCODED).expect("Failed to decode state");
|
||||
let state = Record::load::<NoStdInferenceRecordSettings>(STATE_ENCODED.to_vec())
|
||||
.expect("Failed to decode state");
|
||||
|
||||
model
|
||||
.load(&state)
|
||||
|
|
|
@ -3,9 +3,10 @@ use std::sync::Arc;
|
|||
use crate::data::MNISTBatcher;
|
||||
use crate::model::Model;
|
||||
|
||||
use burn::module::{Module, StateFormat};
|
||||
use burn::module::Module;
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::{Adam, AdamConfig};
|
||||
use burn::record::{DefaultRecordSettings, NoStdTrainingRecordSettings, Record};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
|
||||
|
@ -65,7 +66,7 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
.metric_valid_plot(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
|
||||
.with_file_checkpointer::<DefaultRecordSettings>(2)
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(model, optim);
|
||||
|
@ -79,7 +80,6 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
// We save a bin version of the model to be loaded with no_std environment.
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<f32>()
|
||||
.save(&format!("{ARTIFACT_DIR}/model"), &StateFormat::Bin)
|
||||
.record::<NoStdTrainingRecordSettings>(format!("{ARTIFACT_DIR}/model").into())
|
||||
.expect("Failed to save trained model");
|
||||
}
|
||||
|
|
|
@ -3,8 +3,9 @@ use std::sync::Arc;
|
|||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::batcher::Batcher,
|
||||
module::{Module, State, StateFormat},
|
||||
tensor::backend::Backend,
|
||||
module::{Module, State},
|
||||
record::{DefaultRecordSettings, Record},
|
||||
tensor::{backend::Backend, f16},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
@ -38,10 +39,8 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
|
|||
.init::<B>();
|
||||
|
||||
println!("Loading weights ...");
|
||||
let state = State::<burn::tensor::f16>::load(
|
||||
format!("{artifact_dir}/model").as_str(),
|
||||
&StateFormat::default(),
|
||||
)
|
||||
let state: State<f16> =
|
||||
Record::load::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
|
||||
.expect("Trained model weights");
|
||||
let model = model.load(&state.convert()).expect("Can load weights");
|
||||
let model = model.to_device(&device);
|
||||
|
|
|
@ -5,9 +5,10 @@ use crate::{
|
|||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
|
||||
module::{Module, StateFormat},
|
||||
module::Module,
|
||||
nn::transformer::TransformerEncoderConfig,
|
||||
optim::{Sgd, SgdConfig},
|
||||
record::{DefaultRecordSettings, Record},
|
||||
tensor::backend::ADBackend,
|
||||
train::{
|
||||
metric::{AccuracyMetric, CUDAMetric, LossMetric},
|
||||
|
@ -78,7 +79,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
|||
.metric_valid(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
|
||||
.with_file_checkpointer::<DefaultRecordSettings>(2)
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(model, optim);
|
||||
|
@ -89,7 +90,6 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
|||
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<burn::tensor::f16>()
|
||||
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
|
||||
.record::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
@ -14,7 +14,10 @@ use burn::{
|
|||
LearnerBuilder,
|
||||
},
|
||||
};
|
||||
use burn::{data::dataset::transform::SamplerDataset, module::StateFormat};
|
||||
use burn::{
|
||||
data::dataset::transform::SamplerDataset,
|
||||
record::{DefaultRecordSettings, Record},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Config)]
|
||||
|
@ -76,7 +79,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
|||
.metric_valid(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
|
||||
.with_file_checkpointer::<DefaultRecordSettings>(2)
|
||||
.devices(vec![device])
|
||||
.grads_accumulation(16)
|
||||
.num_epochs(config.num_epochs)
|
||||
|
@ -88,7 +91,6 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
|||
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<burn::tensor::f16>()
|
||||
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
|
||||
.record::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
|
||||
.unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue