mirror of https://github.com/tracel-ai/burn.git
Feat/recorder/custom device (#1165)
This commit is contained in:
parent
e9d1656687
commit
eaa4dc3207
|
@ -37,10 +37,10 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
|
|||
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
|
||||
.expect("Config should exist for the model");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.load(format!("{artifact_dir}/model").into(), &device)
|
||||
.expect("Trained model should exist");
|
||||
|
||||
let model = config.model.init_with::<B>(record).to_device(&device);
|
||||
let model = config.model.init_with::<B>(record);
|
||||
|
||||
let label = item.label;
|
||||
let batcher = MNISTBatcher::new(device);
|
||||
|
|
|
@ -23,7 +23,7 @@ Now that you have a trained model saved to your disk, you can easily load it in
|
|||
// Load model in full precision from MessagePack file
|
||||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
|
||||
model
|
||||
.load_file(model_path, &recorder)
|
||||
.load_file(model_path, &recorder, device)
|
||||
.expect("Should be able to load the model weights from the provided file");
|
||||
```
|
||||
|
||||
|
@ -96,7 +96,7 @@ Afterwards, the model can just as easily be loaded from the record saved on disk
|
|||
```rust, ignore
|
||||
// Load model record on the backend's default device
|
||||
let record: ModelRecord<MyBackend> = NamedMpkFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(model_path.into())
|
||||
.load(model_path.into(), device)
|
||||
.expect("Should be able to load the model weights from the provided file");
|
||||
|
||||
// Directly initialize a new model with the loaded record/weights
|
||||
|
@ -133,7 +133,7 @@ static MODEL_BYTES: &[u8] = include_bytes!("path/to/model.bin");
|
|||
|
||||
// Load model binary record in full precision
|
||||
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
|
||||
.load(MODEL_BYTES.to_vec())
|
||||
.load(MODEL_BYTES.to_vec(), device)
|
||||
.expect("Should be able to load model the model weights from bytes");
|
||||
|
||||
// Load that record with the model
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
use burn_tensor::backend::Backend;
|
||||
|
||||
use crate::{record::Record, LearningRate};
|
||||
|
||||
/// Learning rate scheduler defines how the learning rate will evolve during training.
|
||||
pub trait LrScheduler: Send + Sync {
|
||||
pub trait LrScheduler<B: Backend>: Send + Sync {
|
||||
/// Scheduler associative type to be used when saving and loading the state.
|
||||
type Record: Record;
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Perform the scheduler step, potentially updating its state, and returning the effective
|
||||
/// learning rate.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use burn_tensor::backend::Backend;
|
||||
|
||||
use super::LrScheduler;
|
||||
use crate::LearningRate;
|
||||
|
||||
|
@ -17,7 +19,7 @@ impl From<LearningRate> for ConstantLr {
|
|||
}
|
||||
}
|
||||
|
||||
impl LrScheduler for ConstantLr {
|
||||
impl<B: Backend> LrScheduler<B> for ConstantLr {
|
||||
type Record = ();
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
|
@ -31,7 +33,7 @@ impl LrScheduler for ConstantLr {
|
|||
}
|
||||
}
|
||||
|
||||
impl LrScheduler for LearningRate {
|
||||
impl<B: Backend> LrScheduler<B> for LearningRate {
|
||||
type Record = ();
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use burn_tensor::backend::Backend;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use super::LrScheduler;
|
||||
|
@ -37,7 +39,7 @@ impl NoamLrSchedulerConfig {
|
|||
}
|
||||
}
|
||||
|
||||
impl LrScheduler for NoamLrScheduler {
|
||||
impl<B: Backend> LrScheduler<B> for NoamLrScheduler {
|
||||
type Record = usize;
|
||||
|
||||
fn step(&mut self) -> LearningRate {
|
||||
|
@ -61,6 +63,8 @@ impl LrScheduler for NoamLrScheduler {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::TestBackend;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -72,7 +76,7 @@ mod tests {
|
|||
let mut lr_current = 0.0;
|
||||
|
||||
for _ in 0..warmup_steps {
|
||||
let lr = scheduler.step();
|
||||
let lr = LrScheduler::<TestBackend>::step(&mut scheduler);
|
||||
assert!(
|
||||
lr > lr_current,
|
||||
"Learning rate should increase before the warmup_steps is reached."
|
||||
|
@ -81,7 +85,7 @@ mod tests {
|
|||
}
|
||||
|
||||
for _ in 0..warmup_steps {
|
||||
let lr = scheduler.step();
|
||||
let lr = LrScheduler::<TestBackend>::step(&mut scheduler);
|
||||
assert!(
|
||||
lr < lr_current,
|
||||
"Learning rate should decrease after the warmup_steps is reached."
|
||||
|
|
|
@ -82,7 +82,7 @@ macro_rules! module {
|
|||
/// ```
|
||||
pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
||||
/// Type to save and load the module.
|
||||
type Record: Record;
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Return all the devices found in the underneath module tree added to the given vector
|
||||
/// without duplicates.
|
||||
|
@ -164,11 +164,15 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
///
|
||||
/// The file extension is automatically added depending on the file recorder provided, you
|
||||
/// don't have to specify it.
|
||||
fn save_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
|
||||
fn save_file<FR, PB>(
|
||||
self,
|
||||
file_path: PB,
|
||||
recorder: &FR,
|
||||
) -> Result<(), crate::record::RecorderError> {
|
||||
) -> Result<(), crate::record::RecorderError>
|
||||
where
|
||||
FR: crate::record::FileRecorder<B>,
|
||||
PB: Into<std::path::PathBuf>,
|
||||
{
|
||||
let record = Self::into_record(self);
|
||||
recorder.record(record, file_path.into())
|
||||
}
|
||||
|
@ -183,12 +187,17 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
///
|
||||
/// The file extension is automatically added depending on the file recorder provided, you
|
||||
/// don't have to specify it.
|
||||
fn load_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
|
||||
fn load_file<FR, PB>(
|
||||
self,
|
||||
file_path: PB,
|
||||
recorder: &FR,
|
||||
) -> Result<Self, crate::record::RecorderError> {
|
||||
let record = recorder.load(file_path.into())?;
|
||||
device: &B::Device,
|
||||
) -> Result<Self, crate::record::RecorderError>
|
||||
where
|
||||
FR: crate::record::FileRecorder<B>,
|
||||
PB: Into<std::path::PathBuf>,
|
||||
{
|
||||
let record = recorder.load(file_path.into(), device)?;
|
||||
|
||||
Ok(self.load_record(record))
|
||||
}
|
||||
|
|
|
@ -34,14 +34,14 @@ impl<'de> serde::Deserialize<'de> for ConstantRecord {
|
|||
}
|
||||
}
|
||||
|
||||
impl Record for ConstantRecord {
|
||||
impl<B: Backend> Record<B> for ConstantRecord {
|
||||
type Item<S: PrecisionSettings> = ConstantRecord;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
|
||||
item
|
||||
}
|
||||
}
|
||||
|
@ -213,7 +213,7 @@ mod tests {
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::Tensor;
|
||||
use burn_tensor::{Device, Tensor};
|
||||
|
||||
use crate::TestBackend;
|
||||
use crate::{
|
||||
|
@ -226,21 +226,31 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn tensor_load_record_setting() {
|
||||
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &Default::default());
|
||||
let device: &Device<TestAutodiffBackend> = &Default::default();
|
||||
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);
|
||||
|
||||
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
|
||||
let bytes = byte_recorder
|
||||
.record(tensor.clone().into_record(), ())
|
||||
.unwrap();
|
||||
let bytes = Recorder::<TestAutodiffBackend>::record(
|
||||
&byte_recorder,
|
||||
tensor.clone().into_record(),
|
||||
(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let no_grad_is_require_grad = tensor
|
||||
.clone()
|
||||
.no_grad()
|
||||
.load_record(byte_recorder.load(bytes.clone()).unwrap())
|
||||
.load_record(
|
||||
Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
|
||||
.unwrap(),
|
||||
)
|
||||
.is_require_grad();
|
||||
|
||||
let with_default_is_require_grad = tensor
|
||||
.load_record(byte_recorder.load(bytes).unwrap())
|
||||
.load_record(
|
||||
Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
|
||||
.unwrap(),
|
||||
)
|
||||
.is_require_grad();
|
||||
|
||||
assert!(!no_grad_is_require_grad);
|
||||
|
|
|
@ -228,7 +228,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_load_record_setting() {
|
||||
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &Default::default());
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device);
|
||||
|
||||
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
|
||||
let bytes = byte_recorder
|
||||
|
@ -237,12 +238,12 @@ mod tests {
|
|||
|
||||
let no_grad_is_require_grad = Param::from(tensor.clone())
|
||||
.no_grad()
|
||||
.load_record(byte_recorder.load(bytes.clone()).unwrap())
|
||||
.load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
|
||||
.value
|
||||
.is_require_grad();
|
||||
|
||||
let with_default_is_require_grad = Param::from(tensor)
|
||||
.load_record(byte_recorder.load(bytes).unwrap())
|
||||
.load_record(byte_recorder.load(bytes, &device).unwrap())
|
||||
.value
|
||||
.is_require_grad();
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ where
|
|||
B: AutodiffBackend,
|
||||
{
|
||||
/// Optimizer associative type to be used when saving and loading the state.
|
||||
type Record: Record;
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Perform the optimizer step using the given learning rate and gradients.
|
||||
/// The updated module is returned.
|
||||
|
|
|
@ -18,7 +18,7 @@ where
|
|||
B: AutodiffBackend,
|
||||
{
|
||||
optim: O,
|
||||
records: HashMap<ParamId, AdaptorRecord<O, B::InnerBackend>>,
|
||||
records: HashMap<ParamId, AdaptorRecord<O, B>>,
|
||||
module: PhantomData<M>,
|
||||
grad_clipping: Option<GradientClipping>,
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ where
|
|||
M: AutodiffModule<B>,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
type Record = HashMap<ParamId, AdaptorRecord<O, B::InnerBackend>>;
|
||||
type Record = HashMap<ParamId, AdaptorRecord<O, B>>;
|
||||
|
||||
fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M {
|
||||
let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
|
||||
|
@ -102,7 +102,7 @@ where
|
|||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
optimizer: &'a O,
|
||||
records: &'a mut HashMap<ParamId, AdaptorRecord<O, B::InnerBackend>>,
|
||||
records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
|
||||
grads: &'a mut GradientsParams,
|
||||
lr: LearningRate,
|
||||
phantom: PhantomData<M>,
|
||||
|
|
|
@ -11,7 +11,7 @@ where
|
|||
B: Backend,
|
||||
{
|
||||
/// The state of the optimizer. It also implements [record](Record), so that it can be saved.
|
||||
type State<const D: usize>: Record + Clone + 'static;
|
||||
type State<const D: usize>: Record<B> + Clone + 'static;
|
||||
|
||||
/// The optimizer step is performed for one tensor at a time with its gradient and state.
|
||||
///
|
||||
|
|
|
@ -3,29 +3,37 @@ use crate::{
|
|||
optim::SimpleOptimizer,
|
||||
record::{PrecisionSettings, Record},
|
||||
};
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::backend::AutodiffBackend;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
|
||||
///
|
||||
/// Records are versioned for backward compatibility, so old records can be loaded.
|
||||
pub enum AdaptorRecord<O: SimpleOptimizer<B>, B: Backend> {
|
||||
pub enum AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
/// Version 1.
|
||||
V1(AdaptorRecordV1<O, B>),
|
||||
V1(AdaptorRecordV1<O, B::InnerBackend>),
|
||||
}
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItem<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
|
||||
pub enum AdaptorRecordItem<
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
S: PrecisionSettings,
|
||||
> {
|
||||
/// Version 1.
|
||||
V1(AdaptorRecordItemV1<O, B, S>),
|
||||
V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
|
||||
}
|
||||
|
||||
impl<O, B> Record for AdaptorRecord<O, B>
|
||||
impl<O, B> Record<B> for AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B>,
|
||||
B: Backend,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;
|
||||
|
||||
|
@ -35,17 +43,17 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item)),
|
||||
AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, B> Clone for AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B>,
|
||||
B: Backend,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
|
@ -56,8 +64,8 @@ where
|
|||
|
||||
impl<O, B> AdaptorRecord<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B>,
|
||||
B: Backend,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
/// Converts the record into the optimizer state.
|
||||
///
|
||||
|
|
|
@ -53,28 +53,28 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
|
|||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
|
||||
/// Rank 1.
|
||||
Rank1(<O::State<1> as Record>::Item<S>),
|
||||
Rank1(<O::State<1> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 2.
|
||||
Rank2(<O::State<2> as Record>::Item<S>),
|
||||
Rank2(<O::State<2> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 3.
|
||||
Rank3(<O::State<3> as Record>::Item<S>),
|
||||
Rank3(<O::State<3> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 4.
|
||||
Rank4(<O::State<4> as Record>::Item<S>),
|
||||
Rank4(<O::State<4> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 5.
|
||||
Rank5(<O::State<5> as Record>::Item<S>),
|
||||
Rank5(<O::State<5> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 6.
|
||||
Rank6(<O::State<6> as Record>::Item<S>),
|
||||
Rank6(<O::State<6> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 7.
|
||||
Rank7(<O::State<7> as Record>::Item<S>),
|
||||
Rank7(<O::State<7> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 8.
|
||||
Rank8(<O::State<8> as Record>::Item<S>),
|
||||
Rank8(<O::State<8> as Record<B>>::Item<S>),
|
||||
}
|
||||
|
||||
impl<O, B> AdaptorRecordV1<O, B>
|
||||
|
@ -134,7 +134,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<O, B> Record for AdaptorRecordV1<O, B>
|
||||
impl<O, B> Record<B> for AdaptorRecordV1<O, B>
|
||||
where
|
||||
O: SimpleOptimizer<B>,
|
||||
B: Backend,
|
||||
|
@ -154,31 +154,31 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
AdaptorRecordItemV1::Rank1(item) => {
|
||||
AdaptorRecordV1::Rank1(<O::State<1> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank2(item) => {
|
||||
AdaptorRecordV1::Rank2(<O::State<2> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank3(item) => {
|
||||
AdaptorRecordV1::Rank3(<O::State<3> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank4(item) => {
|
||||
AdaptorRecordV1::Rank4(<O::State<4> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank5(item) => {
|
||||
AdaptorRecordV1::Rank5(<O::State<5> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank6(item) => {
|
||||
AdaptorRecordV1::Rank6(<O::State<6> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank7(item) => {
|
||||
AdaptorRecordV1::Rank7(<O::State<7> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank8(item) => {
|
||||
AdaptorRecordV1::Rank8(<O::State<8> as Record>::from_item(item))
|
||||
AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
pub use burn_derive::Record;
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use super::PrecisionSettings;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
|
||||
pub trait Record: Send + Sync {
|
||||
pub trait Record<B: Backend>: Send + Sync {
|
||||
/// Type of the item that can be serialized and deserialized.
|
||||
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
|
||||
|
||||
|
@ -12,5 +13,5 @@ pub trait Record: Send + Sync {
|
|||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
|
||||
|
||||
/// Convert the given item into a record.
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self;
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
|
||||
use burn_tensor::backend::Backend;
|
||||
use core::marker::PhantomData;
|
||||
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
@ -6,8 +7,8 @@ use std::io::{BufReader, BufWriter};
|
|||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
/// Recorder trait specialized to save and load data to and from files.
|
||||
pub trait FileRecorder:
|
||||
Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
|
||||
pub trait FileRecorder<B: Backend>:
|
||||
Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
|
||||
{
|
||||
/// File extension of the format used by the recorder.
|
||||
fn file_extension() -> &'static str;
|
||||
|
@ -52,34 +53,34 @@ pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
|
|||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> FileRecorder for BinGzFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"bin.gz"
|
||||
}
|
||||
}
|
||||
impl<S: PrecisionSettings> FileRecorder for BinFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"bin"
|
||||
}
|
||||
}
|
||||
impl<S: PrecisionSettings> FileRecorder for JsonGzFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"json.gz"
|
||||
}
|
||||
}
|
||||
impl<S: PrecisionSettings> FileRecorder for PrettyJsonFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"json"
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> FileRecorder for NamedMpkGzFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"mpk.gz"
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> FileRecorder for NamedMpkFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"mpk"
|
||||
}
|
||||
|
@ -89,7 +90,7 @@ macro_rules! str2reader {
|
|||
(
|
||||
$file:expr
|
||||
) => {{
|
||||
$file.set_extension(<Self as FileRecorder>::file_extension());
|
||||
$file.set_extension(<Self as FileRecorder<B>>::file_extension());
|
||||
let path = $file.as_path();
|
||||
|
||||
File::open(path)
|
||||
|
@ -105,7 +106,7 @@ macro_rules! str2writer {
|
|||
(
|
||||
$file:expr
|
||||
) => {{
|
||||
$file.set_extension(<Self as FileRecorder>::file_extension());
|
||||
$file.set_extension(<Self as FileRecorder<B>>::file_extension());
|
||||
let path = $file.as_path();
|
||||
|
||||
if path.exists() {
|
||||
|
@ -122,7 +123,7 @@ macro_rules! str2writer {
|
|||
}};
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for BinGzFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
|
@ -153,7 +154,7 @@ impl<S: PrecisionSettings> Recorder for BinGzFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for BinFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
|
@ -179,7 +180,7 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for JsonGzFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
|
@ -208,7 +209,7 @@ impl<S: PrecisionSettings> Recorder for JsonGzFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for PrettyJsonFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
|
@ -234,7 +235,7 @@ impl<S: PrecisionSettings> Recorder for PrettyJsonFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for NamedMpkGzFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
|
@ -263,7 +264,7 @@ impl<S: PrecisionSettings> Recorder for NamedMpkGzFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for NamedMpkFileRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
|
@ -346,14 +347,18 @@ mod tests {
|
|||
test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
|
||||
}
|
||||
|
||||
fn test_can_save_and_load<Recorder: FileRecorder>(recorder: Recorder) {
|
||||
fn test_can_save_and_load<Recorder>(recorder: Recorder)
|
||||
where
|
||||
Recorder: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let model_before = create_model(&device);
|
||||
recorder
|
||||
.record(model_before.clone().into_record(), file_path())
|
||||
.unwrap();
|
||||
|
||||
let model_after = create_model(&device).load_record(recorder.load(file_path()).unwrap());
|
||||
let model_after =
|
||||
create_model(&device).load_record(recorder.load(file_path(), &device).unwrap());
|
||||
|
||||
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
|
||||
let model_bytes_before = byte_recorder
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::backend::Backend;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
/// Recorder trait specialized to save and load data to and from bytes.
|
||||
|
@ -8,8 +9,8 @@ use serde::{de::DeserializeOwned, Serialize};
|
|||
///
|
||||
/// This is especially useful in no_std environment where weights are stored directly in
|
||||
/// compiled binaries.
|
||||
pub trait BytesRecorder:
|
||||
Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
|
||||
pub trait BytesRecorder<B: Backend>:
|
||||
Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -19,9 +20,9 @@ pub struct BinBytesRecorder<S: PrecisionSettings> {
|
|||
_settings: core::marker::PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> BytesRecorder for BinBytesRecorder<S> {}
|
||||
impl<S: PrecisionSettings, B: Backend> BytesRecorder<B> for BinBytesRecorder<S> {}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for BinBytesRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinBytesRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = ();
|
||||
type RecordOutput = Vec<u8>;
|
||||
|
@ -48,10 +49,10 @@ pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
|
|||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl<S: PrecisionSettings> BytesRecorder for NamedMpkBytesRecorder<S> {}
|
||||
impl<S: PrecisionSettings, B: Backend> BytesRecorder<B> for NamedMpkBytesRecorder<S> {}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl<S: PrecisionSettings> Recorder for NamedMpkBytesRecorder<S> {
|
||||
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkBytesRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = ();
|
||||
type RecordOutput = Vec<u8>;
|
||||
|
@ -87,14 +88,17 @@ mod tests {
|
|||
test_can_save_and_load(NamedMpkBytesRecorder::<FullPrecisionSettings>::default())
|
||||
}
|
||||
|
||||
fn test_can_save_and_load<Recorder: BytesRecorder>(recorder: Recorder) {
|
||||
fn test_can_save_and_load<Recorder>(recorder: Recorder)
|
||||
where
|
||||
Recorder: BytesRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let model1 = create_model::<TestBackend>(&device);
|
||||
let model2 = create_model::<TestBackend>(&device);
|
||||
let bytes1 = recorder.record(model1.into_record(), ()).unwrap();
|
||||
let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap();
|
||||
|
||||
let model2_after = model2.load_record(recorder.load(bytes1.clone()).unwrap());
|
||||
let model2_after = model2.load_record(recorder.load(bytes1.clone(), &device).unwrap());
|
||||
let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap();
|
||||
|
||||
assert_ne!(bytes1, bytes2);
|
||||
|
|
|
@ -16,55 +16,76 @@ use crate::module::{Param, ParamId};
|
|||
use burn_tensor::{DataSerialize, Element};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
impl Record for () {
|
||||
impl<B> Record<B> for ()
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ();
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(_item: Self::Item<S>) -> Self {}
|
||||
fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
|
||||
}
|
||||
|
||||
impl<T: Record> Record for Vec<T> {
|
||||
impl<T, B> Record<B> for Vec<T>
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self.into_iter().map(Record::into_item).collect()
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
item.into_iter().map(Record::from_item).collect()
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
item.into_iter()
|
||||
.map(|i| Record::from_item(i, device))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Record> Record for Option<T> {
|
||||
impl<T, B> Record<B> for Option<T>
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = Option<T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self.map(Record::into_item)
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
item.map(Record::from_item)
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
item.map(|i| Record::from_item(i, device))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T: Record + core::fmt::Debug> Record for [T; N] {
|
||||
impl<const N: usize, T, B> Record<B> for [T; N]
|
||||
where
|
||||
T: Record<B> + core::fmt::Debug,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self.map(Record::into_item).into_iter().collect()
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
item.into_iter()
|
||||
.map(Record::from_item)
|
||||
.map(|i| Record::from_item(i, device))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap_or_else(|_| panic!("An arrar of size {N}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Record> Record for HashMap<ParamId, T> {
|
||||
impl<T, B> Record<B> for HashMap<ParamId, T>
|
||||
where
|
||||
T: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
|
@ -75,23 +96,27 @@ impl<T: Record> Record for HashMap<ParamId, T> {
|
|||
items
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
let mut record = HashMap::with_capacity(item.len());
|
||||
item.into_iter().for_each(|(id, item)| {
|
||||
record.insert(ParamId::from(id), T::from_item(item));
|
||||
record.insert(ParamId::from(id), T::from_item(item, device));
|
||||
});
|
||||
record
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Element> Record for DataSerialize<E> {
|
||||
impl<E, B> Record<B> for DataSerialize<E>
|
||||
where
|
||||
E: Element,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = DataSerialize<S::FloatElem>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self.convert()
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
|
||||
item.convert()
|
||||
}
|
||||
}
|
||||
|
@ -103,57 +128,72 @@ pub struct ParamSerde<T> {
|
|||
param: T,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
|
||||
impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
ParamSerde::new(self.id.into_string(), self.value.into_item())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Param::new(
|
||||
ParamId::from(item.id),
|
||||
Tensor::from_item(item.param).require_grad(), // Same behavior as when we create a new
|
||||
// Param from a tensor.
|
||||
Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new
|
||||
// Param from a tensor.
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Int>> {
|
||||
impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
ParamSerde::new(self.id.into_string(), self.value.into_item())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Param::new(ParamId::from(item.id), Tensor::from_item(item.param))
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Param::new(
|
||||
ParamId::from(item.id),
|
||||
Tensor::from_item(item.param, device),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Bool>> {
|
||||
impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
ParamSerde::new(self.id.into_string(), self.value.into_item::<S>())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Param::new(ParamId::from(item.id), Tensor::from_item::<S>(item.param))
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Param::new(
|
||||
ParamId::from(item.id),
|
||||
Tensor::from_item::<S>(item.param, device),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Type that can be serialized as is without any conversion.
|
||||
macro_rules! primitive {
|
||||
($type:ty) => {
|
||||
impl Record for $type {
|
||||
impl<B: Backend> Record<B> for $type {
|
||||
type Item<S: PrecisionSettings> = $type;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
self
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
|
||||
item
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
use core::any::type_name;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use alloc::format;
|
||||
use alloc::string::{String, ToString};
|
||||
use burn_tensor::backend::Backend;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
||||
use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
|
||||
|
@ -13,7 +15,9 @@ use super::{
|
|||
};
|
||||
|
||||
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
|
||||
pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone {
|
||||
pub trait Recorder<B: Backend>:
|
||||
Send + Sync + core::default::Default + core::fmt::Debug + Clone
|
||||
{
|
||||
/// Type of the settings used by the recorder.
|
||||
type Settings: PrecisionSettings;
|
||||
|
||||
|
@ -36,11 +40,14 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
|
|||
/// # Returns
|
||||
///
|
||||
/// The output of the recording.
|
||||
fn record<R: Record>(
|
||||
fn record<R>(
|
||||
&self,
|
||||
record: R,
|
||||
args: Self::RecordArgs,
|
||||
) -> Result<Self::RecordOutput, RecorderError> {
|
||||
) -> Result<Self::RecordOutput, RecorderError>
|
||||
where
|
||||
R: Record<B>,
|
||||
{
|
||||
let item = record.into_item::<Self::Settings>();
|
||||
let item = BurnRecord::new::<Self>(item);
|
||||
|
||||
|
@ -48,12 +55,15 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
|
|||
}
|
||||
|
||||
/// Load an item from the given arguments.
|
||||
fn load<R: Record>(&self, args: Self::LoadArgs) -> Result<R, RecorderError> {
|
||||
let item: BurnRecord<R::Item<Self::Settings>> =
|
||||
fn load<R>(&self, args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
|
||||
where
|
||||
R: Record<B>,
|
||||
{
|
||||
let item: BurnRecord<R::Item<Self::Settings>, B> =
|
||||
self.load_item(args.clone()).map_err(|err| {
|
||||
if let Ok(record) = self.load_item::<BurnRecordNoItem>(args.clone()) {
|
||||
let mut message = "Unable to load record.".to_string();
|
||||
let metadata = recorder_metadata::<Self>();
|
||||
let metadata = recorder_metadata::<Self, B>();
|
||||
if metadata.float != record.metadata.float {
|
||||
message += format!(
|
||||
"\nMetadata has a different float type: Actual {:?}, Expected {:?}",
|
||||
|
@ -91,7 +101,7 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
|
|||
err
|
||||
})?;
|
||||
|
||||
Ok(R::from_item(item.item))
|
||||
Ok(R::from_item(item.item, device))
|
||||
}
|
||||
|
||||
/// Saves an item.
|
||||
|
@ -123,10 +133,16 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
|
|||
/// # Returns
|
||||
///
|
||||
/// The loaded item.
|
||||
fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>;
|
||||
fn load_item<I>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>
|
||||
where
|
||||
I: DeserializeOwned;
|
||||
}
|
||||
|
||||
fn recorder_metadata<R: Recorder>() -> BurnMetadata {
|
||||
fn recorder_metadata<R, B>() -> BurnMetadata
|
||||
where
|
||||
R: Recorder<B>,
|
||||
B: Backend,
|
||||
{
|
||||
BurnMetadata::new(
|
||||
type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),
|
||||
type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),
|
||||
|
@ -181,15 +197,17 @@ pub struct BurnMetadata {
|
|||
|
||||
/// Record that can be saved by a [Recorder](Recorder).
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct BurnRecord<I> {
|
||||
pub struct BurnRecord<I, B: Backend> {
|
||||
/// Metadata of the record.
|
||||
pub metadata: BurnMetadata,
|
||||
|
||||
/// Item to record.
|
||||
pub item: I,
|
||||
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<I> BurnRecord<I> {
|
||||
impl<I, B: Backend> BurnRecord<I, B> {
|
||||
/// Creates a new record.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -199,10 +217,14 @@ impl<I> BurnRecord<I> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The new record.
|
||||
pub fn new<R: Recorder>(item: I) -> Self {
|
||||
let metadata = recorder_metadata::<R>();
|
||||
pub fn new<R: Recorder<B>>(item: I) -> Self {
|
||||
let metadata = recorder_metadata::<R, B>();
|
||||
|
||||
Self { metadata, item }
|
||||
Self {
|
||||
metadata,
|
||||
item,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,8 +276,10 @@ pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;
|
|||
mod tests {
|
||||
static FILE_PATH: &str = "/tmp/burn_test_record";
|
||||
|
||||
use crate::TestBackend;
|
||||
|
||||
use super::*;
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{Device, ElementConversion};
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
|
@ -265,7 +289,11 @@ mod tests {
|
|||
value: S::FloatElem,
|
||||
}
|
||||
|
||||
impl<D: PrecisionSettings> Record for Item<D> {
|
||||
impl<D, B> Record<B> for Item<D>
|
||||
where
|
||||
D: PrecisionSettings,
|
||||
B: Backend,
|
||||
{
|
||||
type Item<S: PrecisionSettings> = Item<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
|
@ -274,7 +302,7 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
|
||||
Item {
|
||||
value: item.value.elem(),
|
||||
}
|
||||
|
@ -282,15 +310,19 @@ mod tests {
|
|||
}
|
||||
|
||||
let item = Item::<FullPrecisionSettings>::new(16.elem());
|
||||
let device: Device<TestBackend> = Default::default();
|
||||
|
||||
// Serialize in f32.
|
||||
let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
|
||||
recorder.record(item, FILE_PATH.into()).unwrap();
|
||||
Recorder::<TestBackend>::record(&recorder, item, FILE_PATH.into()).unwrap();
|
||||
|
||||
// Can't deserialize f32 into f16.
|
||||
let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
|
||||
recorder
|
||||
.load::<Item<FullPrecisionSettings>>(FILE_PATH.into())
|
||||
.unwrap();
|
||||
Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(
|
||||
&recorder,
|
||||
FILE_PATH.into(),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ impl<'de> Deserialize<'de> for BoolTensorSerde {
|
|||
|
||||
// --- RECORD IMPLEMENTATIONS --- //
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Tensor<B, D> {
|
||||
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
|
||||
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
|
@ -96,12 +96,12 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D> {
|
|||
FloatTensorSerde::new(self.into_data().convert().serialize())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Tensor::from_data(item.data.convert::<B::FloatElem>(), &B::Device::default())
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Tensor::from_data(item.data.convert::<B::FloatElem>(), device)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
|
||||
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
|
||||
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
|
@ -112,12 +112,12 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
|
|||
IntTensorSerde::new(self.into_data().convert().serialize())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Tensor::from_data(item.data.convert(), &B::Device::default())
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Tensor::from_data(item.data.convert(), device)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
|
||||
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
|
||||
type Item<S: PrecisionSettings> = BoolTensorSerde;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
|
@ -128,7 +128,7 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
|
|||
BoolTensorSerde::new(self.into_data().serialize())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Tensor::from_data(item.data, &B::Device::default())
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Tensor::from_data(item.data, device)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -192,7 +192,7 @@ mod tests {
|
|||
|
||||
fn deserialize_with_new_optional_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder,
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf = file_path(format!("deserialize_with_new_optional_field-{name}"));
|
||||
|
@ -206,7 +206,8 @@ mod tests {
|
|||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result = recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone());
|
||||
let result =
|
||||
recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
|
@ -218,7 +219,7 @@ mod tests {
|
|||
recorder: R,
|
||||
) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder,
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf =
|
||||
|
@ -234,7 +235,7 @@ mod tests {
|
|||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone());
|
||||
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
|
@ -243,7 +244,7 @@ mod tests {
|
|||
|
||||
fn deserialize_with_new_constant_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder,
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf = file_path(format!("deserialize_with_new_constant_field-{name}"));
|
||||
|
@ -257,7 +258,8 @@ mod tests {
|
|||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result = recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(file_path.clone());
|
||||
let result =
|
||||
recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
|
@ -269,7 +271,7 @@ mod tests {
|
|||
recorder: R,
|
||||
) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder,
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf =
|
||||
|
@ -285,7 +287,7 @@ mod tests {
|
|||
recorder
|
||||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone());
|
||||
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
|
@ -294,7 +296,7 @@ mod tests {
|
|||
|
||||
fn deserialize_with_new_field_order<R>(name: &str, recorder: R) -> Result<(), RecorderError>
|
||||
where
|
||||
R: FileRecorder,
|
||||
R: FileRecorder<TestBackend>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let file_path: PathBuf = file_path(format!("deserialize_with_new_field_order-{name}"));
|
||||
|
@ -309,7 +311,8 @@ mod tests {
|
|||
.record(model.into_record(), file_path.clone())
|
||||
.unwrap();
|
||||
|
||||
let result = recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(file_path.clone());
|
||||
let result =
|
||||
recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(file_path.clone(), &device);
|
||||
std::fs::remove_file(file_path).ok();
|
||||
|
||||
result?;
|
||||
|
|
|
@ -22,12 +22,19 @@ struct RecordDeriveCodegen {
|
|||
name_item: Ident,
|
||||
gen: StructRecordItemCodegen,
|
||||
generics: Generics,
|
||||
has_backend: bool,
|
||||
}
|
||||
|
||||
impl RecordDeriveCodegen {
|
||||
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
let name_record = ast.ident.clone();
|
||||
let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span());
|
||||
let has_backend = ast
|
||||
.generics
|
||||
.type_params()
|
||||
.map(|param| param.ident == "B")
|
||||
.reduce(|accum, is_backend| is_backend || accum)
|
||||
.unwrap_or(false);
|
||||
|
||||
Self {
|
||||
name_record,
|
||||
|
@ -39,6 +46,7 @@ impl RecordDeriveCodegen {
|
|||
.collect(),
|
||||
),
|
||||
generics: ast.generics.clone(),
|
||||
has_backend,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,7 +59,8 @@ impl RecordDeriveCodegen {
|
|||
generics.params.push(param);
|
||||
}
|
||||
|
||||
self.gen.gen_item_type(&self.name_item, &generics)
|
||||
self.gen
|
||||
.gen_item_type(&self.name_item, &generics, self.has_backend)
|
||||
}
|
||||
|
||||
/// Generate the implementation for the Record trait.
|
||||
|
@ -61,12 +70,18 @@ impl RecordDeriveCodegen {
|
|||
let (_, ty_generics_item, _) = item_generics.split_for_impl();
|
||||
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
|
||||
|
||||
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
|
||||
impl_generic
|
||||
} else {
|
||||
quote! { #impl_generics }
|
||||
};
|
||||
|
||||
let name_item = &self.name_item;
|
||||
let into_item_fn = self.gen.gen_into_item(name_item);
|
||||
let from_item_fn = self.gen.gen_from_item();
|
||||
|
||||
quote! {
|
||||
impl #impl_generics burn::record::Record for #name #ty_generics #where_clause {
|
||||
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
|
||||
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
|
||||
|
||||
#into_item_fn
|
||||
|
@ -76,6 +91,20 @@ impl RecordDeriveCodegen {
|
|||
}
|
||||
}
|
||||
|
||||
fn impl_generics(&self) -> Option<TokenStream> {
|
||||
if self.has_backend {
|
||||
return None;
|
||||
}
|
||||
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
let mut generics = self.generics.clone();
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
|
||||
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
|
||||
|
||||
Some(quote! {#impl_generics})
|
||||
}
|
||||
|
||||
fn record_item_generics(&self) -> Generics {
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.generics.clone();
|
||||
|
@ -83,6 +112,11 @@ impl RecordDeriveCodegen {
|
|||
generics.params.push(param);
|
||||
}
|
||||
|
||||
if !self.has_backend {
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,12 @@ use syn::Generics;
|
|||
/// Basic trait to be implemented for record generation.
|
||||
pub(crate) trait RecordItemCodegen {
|
||||
/// Generate the record item type (i.e a struct)
|
||||
fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream;
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream;
|
||||
/// Generate the into_item function.
|
||||
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
|
||||
/// Generate the from item function.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::shared::field::FieldTypeAnalyzer;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::Generics;
|
||||
use syn::{parse_quote, Generics};
|
||||
|
||||
use super::codegen::RecordItemCodegen;
|
||||
|
||||
|
@ -11,7 +11,12 @@ pub(crate) struct StructRecordItemCodegen {
|
|||
}
|
||||
|
||||
impl RecordItemCodegen for StructRecordItemCodegen {
|
||||
fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream {
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream {
|
||||
let mut fields = quote! {};
|
||||
let mut bounds = quote! {};
|
||||
|
||||
|
@ -21,15 +26,25 @@ impl RecordItemCodegen for StructRecordItemCodegen {
|
|||
|
||||
fields.extend(quote! {
|
||||
/// Field to be serialized.
|
||||
pub #name: <#ty as burn::record::Record>::Item<S>,
|
||||
pub #name: <#ty as burn::record::Record<B>>::Item<S>,
|
||||
});
|
||||
bounds.extend(quote! {
|
||||
<#ty as burn::record::Record>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
|
||||
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
|
||||
});
|
||||
}
|
||||
let bound = bounds.to_string();
|
||||
|
||||
let (generics, _, generics_where) = generics.split_for_impl();
|
||||
let (generics, generics_where) = if !has_backend {
|
||||
let mut generics = generics.clone();
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
let (generics, _, generics_where) = generics.split_for_impl();
|
||||
(quote! { #generics }, quote! { #generics_where })
|
||||
} else {
|
||||
let (generics, _, generics_where) = generics.split_for_impl();
|
||||
(quote! { #generics }, quote! { #generics_where })
|
||||
};
|
||||
|
||||
quote! {
|
||||
|
||||
/// The record item type for the module.
|
||||
|
@ -49,7 +64,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
|
|||
let name = &field.field.ident;
|
||||
|
||||
body_into_item.extend(quote! {
|
||||
#name: burn::record::Record::into_item::<S>(self.#name),
|
||||
#name: burn::record::Record::<B>::into_item::<S>(self.#name),
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -69,12 +84,12 @@ impl RecordItemCodegen for StructRecordItemCodegen {
|
|||
let name = &field.field.ident;
|
||||
|
||||
body_from_item.extend(quote! {
|
||||
#name: burn::record::Record::from_item::<S>(item.#name),
|
||||
#name: burn::record::Record::<B>::from_item::<S>(item.#name, device),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Self {
|
||||
#body_from_item
|
||||
}
|
||||
|
|
|
@ -49,6 +49,9 @@ pub struct BurnGraph<PS: PrecisionSettings> {
|
|||
graph_output_types: Vec<Type>,
|
||||
}
|
||||
|
||||
// The backend used for recording.
|
||||
type Backend = burn_ndarray::NdArray;
|
||||
|
||||
impl<PS: PrecisionSettings> BurnGraph<PS> {
|
||||
/// Register a new operation node into the graph.
|
||||
///
|
||||
|
@ -96,14 +99,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
|
||||
match record_type {
|
||||
RecordType::PrettyJson => {
|
||||
PrettyJsonFileRecorder::<PS>::new()
|
||||
.save_item(
|
||||
BurnRecord::new::<PrettyJsonFileRecorder<PS>>(StructMap(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let recorder = PrettyJsonFileRecorder::<PS>::new();
|
||||
|
||||
Recorder::<Backend>::save_item(
|
||||
&recorder,
|
||||
BurnRecord::<_, Backend>::new::<PrettyJsonFileRecorder<PS>>(StructMap(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!embed_states,
|
||||
|
@ -116,14 +121,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
);
|
||||
}
|
||||
RecordType::NamedMpkGz => {
|
||||
NamedMpkGzFileRecorder::<PS>::new()
|
||||
.save_item(
|
||||
BurnRecord::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let recorder = NamedMpkGzFileRecorder::<PS>::new();
|
||||
|
||||
Recorder::<Backend>::save_item(
|
||||
&recorder,
|
||||
BurnRecord::<_, Backend>::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!embed_states,
|
||||
|
@ -136,14 +143,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
}
|
||||
|
||||
RecordType::NamedMpk => {
|
||||
NamedMpkFileRecorder::<PS>::new()
|
||||
.save_item(
|
||||
BurnRecord::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let recorder = NamedMpkFileRecorder::<PS>::new();
|
||||
|
||||
Recorder::<Backend>::save_item(
|
||||
&recorder,
|
||||
BurnRecord::<_, Backend>::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!embed_states,
|
||||
|
@ -157,14 +166,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
}
|
||||
|
||||
RecordType::Bincode => {
|
||||
BinFileRecorder::<PS>::new()
|
||||
.save_item(
|
||||
BurnRecord::new::<BinFileRecorder<PS>>(StructTuple(BurnGraphState::new(
|
||||
&self.nodes,
|
||||
))),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let recorder = BinFileRecorder::<PS>::new();
|
||||
|
||||
Recorder::<Backend>::save_item(
|
||||
&recorder,
|
||||
BurnRecord::<_, Backend>::new::<BinFileRecorder<PS>>(StructTuple(
|
||||
BurnGraphState::new(&self.nodes),
|
||||
)),
|
||||
out_file.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if embed_states {
|
||||
self.register_record_embed(out_file);
|
||||
|
@ -349,14 +360,14 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
_blank_!();
|
||||
impl<B: Backend> Default for Model<B> {
|
||||
fn default() -> Self {
|
||||
Self::from_file(#file)
|
||||
Self::from_file(#file, &Default::default())
|
||||
}
|
||||
}
|
||||
_blank_!();
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn from_file(file: &str) -> Self {
|
||||
pub fn from_file(file: &str, device: &B::Device) -> Self {
|
||||
let record = #recorder_ty::new()
|
||||
.load(file.into())
|
||||
.load(file.into(), device)
|
||||
.expect("Record file to exist.");
|
||||
Self::new_with(record)
|
||||
}
|
||||
|
@ -373,7 +384,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
self.imports.register("burn::record::BinBytesRecorder");
|
||||
|
||||
let mut file = file;
|
||||
file.set_extension(BinFileRecorder::<PS>::file_extension());
|
||||
file.set_extension(<BinFileRecorder<PS> as FileRecorder<Backend>>::file_extension());
|
||||
let file = file.to_str().unwrap();
|
||||
self.default = Some(quote! {
|
||||
_blank_!();
|
||||
|
@ -381,14 +392,14 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
_blank_!();
|
||||
impl<B: Backend> Default for Model<B> {
|
||||
fn default() -> Self {
|
||||
Self::from_embedded()
|
||||
Self::from_embedded(&Default::default())
|
||||
}
|
||||
}
|
||||
_blank_!();
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn from_embedded() -> Self {
|
||||
pub fn from_embedded(device: &B::Device) -> Self {
|
||||
let record = BinBytesRecorder::<#precision_ty>::default()
|
||||
.load(EMBEDDED_STATES.to_vec())
|
||||
.load(EMBEDDED_STATES.to_vec(), device)
|
||||
.expect("Failed to decode state");
|
||||
|
||||
Self::new_with(record)
|
||||
|
|
|
@ -1,26 +1,35 @@
|
|||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::record::Record;
|
||||
use burn_core::{record::Record, tensor::backend::Backend};
|
||||
use std::sync::mpsc;
|
||||
|
||||
enum Message<R> {
|
||||
Restore(usize, mpsc::SyncSender<Result<R, CheckpointerError>>),
|
||||
enum Message<R, B: Backend> {
|
||||
Restore(
|
||||
usize,
|
||||
B::Device,
|
||||
mpsc::SyncSender<Result<R, CheckpointerError>>,
|
||||
),
|
||||
Save(usize, R),
|
||||
Delete(usize),
|
||||
End,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct CheckpointerThread<C, R> {
|
||||
struct CheckpointerThread<C, R, B: Backend> {
|
||||
checkpointer: C,
|
||||
receiver: mpsc::Receiver<Message<R>>,
|
||||
receiver: mpsc::Receiver<Message<R, B>>,
|
||||
}
|
||||
|
||||
impl<C: Checkpointer<R>, R: Record> CheckpointerThread<C, R> {
|
||||
impl<C, R, B> CheckpointerThread<C, R, B>
|
||||
where
|
||||
C: Checkpointer<R, B>,
|
||||
R: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
fn run(self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::Restore(epoch, callback) => {
|
||||
let record = self.checkpointer.restore(epoch);
|
||||
Message::Restore(epoch, device, callback) => {
|
||||
let record = self.checkpointer.restore(epoch, &device);
|
||||
callback
|
||||
.send(record)
|
||||
.expect("Can send response through callback channel.");
|
||||
|
@ -42,12 +51,16 @@ impl<C: Checkpointer<R>, R: Record> CheckpointerThread<C, R> {
|
|||
}
|
||||
|
||||
/// Async checkpointer.
|
||||
pub struct AsyncCheckpointer<Record> {
|
||||
sender: mpsc::SyncSender<Message<Record>>,
|
||||
pub struct AsyncCheckpointer<Record, B: Backend> {
|
||||
sender: mpsc::SyncSender<Message<Record, B>>,
|
||||
handler: Option<std::thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl<R: Record + 'static> AsyncCheckpointer<R> {
|
||||
impl<R, B> AsyncCheckpointer<R, B>
|
||||
where
|
||||
R: Record<B> + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
/// Create a new async checkpointer.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -59,7 +72,7 @@ impl<R: Record + 'static> AsyncCheckpointer<R> {
|
|||
/// The async checkpointer.
|
||||
pub fn new<C>(checkpointer: C) -> Self
|
||||
where
|
||||
C: Checkpointer<R> + Send + 'static,
|
||||
C: Checkpointer<R, B> + Send + 'static,
|
||||
{
|
||||
// Only on checkpoint can be done in advance.
|
||||
let (sender, receiver) = mpsc::sync_channel(0);
|
||||
|
@ -70,9 +83,10 @@ impl<R: Record + 'static> AsyncCheckpointer<R> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<R> Checkpointer<R> for AsyncCheckpointer<R>
|
||||
impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
|
||||
where
|
||||
R: Record + 'static,
|
||||
R: Record<B> + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
self.sender
|
||||
|
@ -82,10 +96,10 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::Restore(epoch, sender))
|
||||
.send(Message::Restore(epoch, device.clone(), sender))
|
||||
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
|
||||
|
||||
if let Ok(record) = receiver.recv() {
|
||||
|
@ -104,7 +118,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<E> Drop for AsyncCheckpointer<E> {
|
||||
impl<E, B> Drop for AsyncCheckpointer<E, B>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
self.sender
|
||||
.send(Message::End)
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use burn_core::record::{Record, RecorderError};
|
||||
use burn_core::{
|
||||
record::{Record, RecorderError},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
/// The error type for checkpointer.
|
||||
#[derive(Debug)]
|
||||
|
@ -14,7 +17,11 @@ pub enum CheckpointerError {
|
|||
}
|
||||
|
||||
/// The trait for checkpointer.
|
||||
pub trait Checkpointer<R: Record> {
|
||||
pub trait Checkpointer<R, B>
|
||||
where
|
||||
R: Record<B>,
|
||||
B: Backend,
|
||||
{
|
||||
/// Save the record.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -31,9 +38,10 @@ pub trait Checkpointer<R: Record> {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
/// * `device` - The device used to restore the record.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The record.
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError>;
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::record::{FileRecorder, Record};
|
||||
use burn_core::{
|
||||
record::{FileRecorder, Record},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
/// The file checkpointer.
|
||||
pub struct FileCheckpointer<FR> {
|
||||
|
@ -30,10 +33,11 @@ impl<FR> FileCheckpointer<FR> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<FR, R> Checkpointer<R> for FileCheckpointer<FR>
|
||||
impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
|
||||
where
|
||||
R: Record,
|
||||
FR: FileRecorder,
|
||||
R: Record<B>,
|
||||
FR: FileRecorder<B>,
|
||||
B: Backend,
|
||||
{
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
|
@ -46,12 +50,12 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
|
||||
let record = self
|
||||
.recorder
|
||||
.load(file_path.into())
|
||||
.load(file_path.into(), device)
|
||||
.map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
Ok(record)
|
||||
|
|
|
@ -15,19 +15,26 @@ pub trait LearnerComponents {
|
|||
/// The backend in used for the training.
|
||||
type Backend: AutodiffBackend;
|
||||
/// The learning rate scheduler used for the training.
|
||||
type LrScheduler: LrScheduler;
|
||||
type LrScheduler: LrScheduler<Self::Backend>;
|
||||
/// The model to train.
|
||||
type Model: AutodiffModule<Self::Backend> + core::fmt::Display + 'static;
|
||||
/// The optimizer used for the training.
|
||||
type Optimizer: Optimizer<Self::Model, Self::Backend>;
|
||||
/// The checkpointer used for the model.
|
||||
type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record>;
|
||||
type CheckpointerModel: Checkpointer<
|
||||
<Self::Model as Module<Self::Backend>>::Record,
|
||||
Self::Backend,
|
||||
>;
|
||||
/// The checkpointer used for the optimizer.
|
||||
type CheckpointerOptimizer: Checkpointer<
|
||||
<Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
|
||||
Self::Backend,
|
||||
>;
|
||||
/// The checkpointer used for the scheduler.
|
||||
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
|
||||
type CheckpointerLrScheduler: Checkpointer<
|
||||
<Self::LrScheduler as LrScheduler<Self::Backend>>::Record,
|
||||
Self::Backend,
|
||||
>;
|
||||
type EventProcessor: EventProcessor + 'static;
|
||||
/// The strategy to save and delete checkpoints.
|
||||
type CheckpointerStrategy: CheckpointingStrategy;
|
||||
|
@ -50,12 +57,12 @@ impl<B, LR, M, O, CM, CO, CS, EP, S> LearnerComponents
|
|||
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
LR: LrScheduler,
|
||||
LR: LrScheduler<B>,
|
||||
M: AutodiffModule<B> + core::fmt::Display + 'static,
|
||||
O: Optimizer<M, B>,
|
||||
CM: Checkpointer<M::Record>,
|
||||
CO: Checkpointer<O::Record>,
|
||||
CS: Checkpointer<LR::Record>,
|
||||
CM: Checkpointer<M::Record, B>,
|
||||
CO: Checkpointer<O::Record, B>,
|
||||
CS: Checkpointer<LR::Record, B>,
|
||||
EP: EventProcessor + 'static,
|
||||
S: CheckpointingStrategy,
|
||||
{
|
||||
|
|
|
@ -6,6 +6,7 @@ use burn_core::lr_scheduler::LrScheduler;
|
|||
use burn_core::module::Module;
|
||||
use burn_core::optim::Optimizer;
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::Device;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -79,23 +80,24 @@ impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
|
|||
model: LC::Model,
|
||||
optim: LC::Optimizer,
|
||||
scheduler: LC::LrScheduler,
|
||||
device: &Device<LC::Backend>,
|
||||
epoch: usize,
|
||||
) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
|
||||
let record = self
|
||||
.model
|
||||
.restore(epoch)
|
||||
.restore(epoch, device)
|
||||
.expect("Can load model checkpoint.");
|
||||
let model = model.load_record(record);
|
||||
|
||||
let record = self
|
||||
.optim
|
||||
.restore(epoch)
|
||||
.restore(epoch, device)
|
||||
.expect("Can load optimizer checkpoint.");
|
||||
let optim = optim.load_record(record);
|
||||
|
||||
let record = self
|
||||
.lr_scheduler
|
||||
.restore(epoch)
|
||||
.restore(epoch, device)
|
||||
.expect("Can load learning rate scheduler checkpoint.");
|
||||
let scheduler = scheduler.load_record(record);
|
||||
|
||||
|
|
|
@ -29,16 +29,16 @@ where
|
|||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
O: Optimizer<M, B>,
|
||||
S: LrScheduler,
|
||||
S: LrScheduler<B>,
|
||||
{
|
||||
// Not that complex and very convenient when the traits are
|
||||
// already constrained correctly. Extracting in another type
|
||||
// would be more complex.
|
||||
#[allow(clippy::type_complexity)]
|
||||
checkpointers: Option<(
|
||||
AsyncCheckpointer<M::Record>,
|
||||
AsyncCheckpointer<O::Record>,
|
||||
AsyncCheckpointer<S::Record>,
|
||||
AsyncCheckpointer<M::Record, B>,
|
||||
AsyncCheckpointer<O::Record, B>,
|
||||
AsyncCheckpointer<S::Record, B>,
|
||||
)>,
|
||||
num_epochs: usize,
|
||||
checkpoint: Option<usize>,
|
||||
|
@ -62,7 +62,7 @@ where
|
|||
V: Send + Sync + 'static,
|
||||
M: AutodiffModule<B> + core::fmt::Display + 'static,
|
||||
O: Optimizer<M, B>,
|
||||
S: LrScheduler,
|
||||
S: LrScheduler<B>,
|
||||
{
|
||||
/// Creates a new learner builder.
|
||||
///
|
||||
|
@ -235,7 +235,8 @@ where
|
|||
/// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
|
||||
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
|
||||
where
|
||||
FR: FileRecorder + 'static,
|
||||
FR: FileRecorder<B> + 'static,
|
||||
FR: FileRecorder<B::InnerBackend> + 'static,
|
||||
O::Record: 'static,
|
||||
M::Record: 'static,
|
||||
S::Record: 'static,
|
||||
|
@ -281,9 +282,9 @@ where
|
|||
S,
|
||||
M,
|
||||
O,
|
||||
AsyncCheckpointer<M::Record>,
|
||||
AsyncCheckpointer<O::Record>,
|
||||
AsyncCheckpointer<S::Record>,
|
||||
AsyncCheckpointer<M::Record, B>,
|
||||
AsyncCheckpointer<O::Record, B>,
|
||||
AsyncCheckpointer<S::Record, B>,
|
||||
FullEventProcessor<T, V>,
|
||||
Box<dyn CheckpointingStrategy>,
|
||||
>,
|
||||
|
|
|
@ -135,6 +135,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
self.model,
|
||||
self.optim,
|
||||
self.lr_scheduler,
|
||||
&Default::default(), // Load the checkpoint on the default device.
|
||||
checkpoint,
|
||||
);
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ use burn::data::dataset::source::huggingface::MNISTItem;
|
|||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::batcher::Batcher,
|
||||
module::Module,
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
@ -12,10 +11,10 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
|
|||
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
|
||||
.expect("Config should exist for the model");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.load(format!("{artifact_dir}/model").into(), &device)
|
||||
.expect("Trained model should exist");
|
||||
|
||||
let model = config.model.init_with::<B>(record).to_device(&device);
|
||||
let model = config.model.init_with::<B>(record);
|
||||
|
||||
let label = item.label;
|
||||
let batcher = MNISTBatcher::new(device);
|
||||
|
|
|
@ -133,7 +133,7 @@ impl<B: Backend> Model<B> {
|
|||
/// Constructor
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
model: SqueezenetModel::from_embedded(),
|
||||
model: SqueezenetModel::from_embedded(device),
|
||||
normalizer: Normalizer::new(device),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ pub async fn build_and_load_model() -> Model<Backend> {
|
|||
|
||||
let model: Model<Backend> = Model::new(&Default::default());
|
||||
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
|
||||
.load(STATE_ENCODED.to_vec())
|
||||
.load(STATE_ENCODED.to_vec(), &Default::default())
|
||||
.expect("Failed to decode state");
|
||||
|
||||
model.load_record(record)
|
||||
|
|
|
@ -12,7 +12,6 @@ use crate::{
|
|||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::batcher::Batcher,
|
||||
module::Module,
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
@ -44,7 +43,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
|
|||
// Load pre-trained model weights
|
||||
println!("Loading weights ...");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.load(format!("{artifact_dir}/model").into(), &device)
|
||||
.expect("Trained model weights");
|
||||
|
||||
// Create model using loaded weights
|
||||
|
@ -55,8 +54,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
|
|||
tokenizer.vocab_size(),
|
||||
config.max_seq_length,
|
||||
)
|
||||
.init_with::<B>(record) // Initialize model with loaded weights
|
||||
.to_device(&device); // Move model to computation device
|
||||
.init_with::<B>(record); // Initialize model with loaded weights
|
||||
|
||||
// Run inference on the given text samples
|
||||
println!("Running inference ...");
|
||||
|
|
Loading…
Reference in New Issue