mirror of https://github.com/tracel-ai/burn.git
State serialization/deserialization overhaul (#247)
This commit is contained in:
parent
00625d1527
commit
6f43d983f7
|
@ -42,6 +42,9 @@ ndarray = {version = "0.15.6", default-features = false}
|
|||
num-traits = {version = "0.2.15", default-features = false, features = ["libm"]}# libm is for no_std
|
||||
rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# std_rng is for no_std
|
||||
rand_distr = {version = "0.4.3", default-features = false}
|
||||
uuid = {version = "1.3.0", default-features = false}
|
||||
|
||||
serde = {version = "1.0.155", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed
|
||||
serde_json = {version = "1.0.94", default-features = false}
|
||||
uuid = {version = "1.3.0", default-features = false}
|
||||
rmp-serde = {version = "1.1.1"}
|
||||
bincode = {version = "2.0.0-rc", features=["alloc", "serde"], default-features = false}
|
||||
|
|
|
@ -13,7 +13,6 @@ version = "0.6.0"
|
|||
|
||||
[features]
|
||||
default = ["std"]
|
||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||
std = [
|
||||
"burn-autodiff",
|
||||
"burn-common/std",
|
||||
|
@ -24,10 +23,19 @@ std = [
|
|||
"flate2",
|
||||
"log",
|
||||
"rand/std",
|
||||
"serde_json/std",
|
||||
"serde/std",
|
||||
"serde_json/std",
|
||||
"bincode/std",
|
||||
"half/std",
|
||||
"half/serde", # TODO: set default when https://github.com/starkat99/half-rs/issues/84 is fixed
|
||||
]
|
||||
|
||||
# Serialization formats
|
||||
msgpack = ["rmp-serde"] # Assumes std
|
||||
|
||||
test-tch = [] # To use tch during testing, default uses ndarray.
|
||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
|
||||
|
@ -52,7 +60,11 @@ hashbrown = {workspace = true, features = ["serde"]}# no_std compatible
|
|||
# Serialize Deserialize
|
||||
flate2 = {workspace = true, optional = true}
|
||||
serde = {workspace = true, features = ["derive"]}
|
||||
|
||||
serde_json = {workspace = true, features = ["alloc"]}#Default enables std
|
||||
rmp-serde = {workspace = true, optional = true}
|
||||
bincode = {workspace = true}
|
||||
half = {workspace = true}
|
||||
|
||||
[dev-dependencies]
|
||||
burn-dataset = {path = "../burn-dataset", version = "0.6.0", features = [
|
||||
|
|
|
@ -1,21 +1,17 @@
|
|||
use alloc::{
|
||||
format,
|
||||
string::{String, ToString},
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
use super::ParamId;
|
||||
use crate::tensor::{DataSerialize, Element};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use std::{collections::HashMap, fs::File, path::Path};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use hashbrown::HashMap;
|
||||
#[cfg(feature = "std")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct StateNamed<E> {
|
||||
|
@ -35,6 +31,23 @@ pub enum StateError {
|
|||
FileNotFound(String),
|
||||
}
|
||||
|
||||
/// All supported format.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The default file format is compressed bincode for the smallest file size possible.
|
||||
/// For `no_std` environments, you should use (StateFormat::Bin)[StateFormat::Bin] since compression isn't supported.
|
||||
/// However, the bincode format alone is smaller than compressed `json` or `msgpack`.
|
||||
#[derive(Default, Clone)]
|
||||
pub enum StateFormat {
|
||||
#[default]
|
||||
BinGz,
|
||||
Bin,
|
||||
JsonGz,
|
||||
#[cfg(feature = "msgpack")]
|
||||
MpkGz,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for StateError {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
let mut message = "State error => ".to_string();
|
||||
|
@ -52,10 +65,6 @@ impl core::fmt::Display for StateError {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for StateError {}
|
||||
|
||||
impl<E: Element> StateNamed<E> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
@ -113,41 +122,167 @@ impl<E: Element> State<E> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl<E: Element> State<E>
|
||||
where
|
||||
E: serde::de::DeserializeOwned,
|
||||
E: serde::Serialize,
|
||||
{
|
||||
pub fn save(self, file: &str) -> std::io::Result<()> {
|
||||
let path = Path::new(file);
|
||||
if path.exists() {
|
||||
log::info!("File exists, replacing");
|
||||
std::fs::remove_file(path).unwrap();
|
||||
pub fn to_bin(&self) -> Result<Vec<u8>, StateError> {
|
||||
Ok(bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap())
|
||||
}
|
||||
|
||||
pub fn from_bin(data: &[u8]) -> Result<Self, StateError> {
|
||||
let state = bincode::serde::decode_borrowed_from_slice(data, Self::bin_config()).unwrap();
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
fn bin_config() -> bincode::config::Configuration {
|
||||
bincode::config::standard()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod std_enabled {
|
||||
use super::*;
|
||||
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||
use std::{fs::File, path::Path};
|
||||
|
||||
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
|
||||
impl std::error::Error for StateError {}
|
||||
|
||||
macro_rules! str2reader {
|
||||
(
|
||||
$file:expr,
|
||||
$ext:expr
|
||||
) => {{
|
||||
let path_ref = &format!("{}.{}", $file, $ext);
|
||||
let path = Path::new(path_ref);
|
||||
|
||||
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! str2writer {
|
||||
(
|
||||
$file:expr,
|
||||
$ext:expr
|
||||
) => {{
|
||||
let path_ref = &format!("{}.{}", $file, $ext);
|
||||
let path = Path::new(path_ref);
|
||||
if path.exists() {
|
||||
log::info!("File exists, replacing");
|
||||
std::fs::remove_file(path).unwrap();
|
||||
}
|
||||
|
||||
File::create(path)
|
||||
}};
|
||||
}
|
||||
impl<E: Element> State<E>
|
||||
where
|
||||
E: serde::de::DeserializeOwned,
|
||||
E: serde::Serialize,
|
||||
{
|
||||
/// Save the state to the provided file path using the given [StateFormat](StateFormat).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The file extension will be added automatically depending on the state format.
|
||||
pub fn save(self, file: &str, format: &StateFormat) -> std::io::Result<()> {
|
||||
match format {
|
||||
StateFormat::BinGz => self.save_bingz(file),
|
||||
StateFormat::Bin => self.save_bin(file),
|
||||
StateFormat::JsonGz => self.save_jsongz(file),
|
||||
#[cfg(feature = "msgpack")]
|
||||
StateFormat::MpkGz => self.save_mpkgz(file),
|
||||
}
|
||||
}
|
||||
|
||||
let writer = File::create(path)?;
|
||||
let writer = GzEncoder::new(writer, Compression::default());
|
||||
serde_json::to_writer(writer, &self).unwrap();
|
||||
/// Load the state from the provided file path using the given [StateFormat](StateFormat).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The file extension will be added automatically depending on the state format.
|
||||
pub fn load(file: &str, format: &StateFormat) -> Result<Self, StateError> {
|
||||
match format {
|
||||
StateFormat::BinGz => Self::load_bingz(file),
|
||||
StateFormat::Bin => Self::load_bin(file),
|
||||
StateFormat::JsonGz => Self::load_jsongz(file),
|
||||
#[cfg(feature = "msgpack")]
|
||||
StateFormat::MpkGz => Self::load_mpkgz(file),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
fn save_jsongz(self, file: &str) -> std::io::Result<()> {
|
||||
let writer = str2writer!(file, "json.gz")?;
|
||||
let writer = GzEncoder::new(writer, Compression::default());
|
||||
serde_json::to_writer(writer, &self).unwrap();
|
||||
|
||||
pub fn load(file: &str) -> Result<Self, StateError> {
|
||||
let path = Path::new(file);
|
||||
let reader =
|
||||
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = serde_json::from_reader(reader).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
fn load_jsongz(file: &str) -> Result<Self, StateError> {
|
||||
let reader = str2reader!(file, "json.gz")?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = serde_json::from_reader(reader).unwrap();
|
||||
|
||||
pub fn load_binary(data: &[u8]) -> Result<Self, StateError> {
|
||||
let reader = GzDecoder::new(data);
|
||||
let state = serde_json::from_reader(reader).unwrap();
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
Ok(state)
|
||||
#[cfg(feature = "msgpack")]
|
||||
fn save_mpkgz(self, file: &str) -> std::io::Result<()> {
|
||||
let writer = str2writer!(file, "mpk.gz")?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
rmp_serde::encode::write(&mut writer, &self).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
fn load_mpkgz(file: &str) -> Result<Self, StateError> {
|
||||
let reader = str2reader!(file, "mpk.gz")?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = rmp_serde::decode::from_read(reader).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
fn save_bingz(self, file: &str) -> std::io::Result<()> {
|
||||
let config = Self::bin_config();
|
||||
let writer = str2writer!(file, "bin.gz")?;
|
||||
let mut writer = GzEncoder::new(writer, Compression::default());
|
||||
|
||||
bincode::serde::encode_into_std_write(&self, &mut writer, config).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_bingz(file: &str) -> Result<Self, StateError> {
|
||||
let reader = str2reader!(file, "bin.gz")?;
|
||||
let mut reader = GzDecoder::new(reader);
|
||||
let state =
|
||||
bincode::serde::decode_from_std_read(&mut reader, Self::bin_config()).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
fn save_bin(self, file: &str) -> std::io::Result<()> {
|
||||
let buf = bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap();
|
||||
|
||||
let mut writer = str2writer!(file, "bin")?;
|
||||
std::io::Write::write_all(&mut writer, &buf).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_bin(file: &str) -> Result<Self, StateError> {
|
||||
let mut reader = str2reader!(file, "bin")?;
|
||||
let mut buf = Vec::new();
|
||||
std::io::Read::read_to_end(&mut reader, &mut buf).unwrap();
|
||||
let state =
|
||||
bincode::serde::decode_borrowed_from_slice(&buf, Self::bin_config()).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,21 +305,6 @@ mod tests {
|
|||
assert_eq!(state, state_from);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file() {
|
||||
let model_before = create_model();
|
||||
let state_before = model_before.state();
|
||||
state_before.clone().save("/tmp/test.json").unwrap();
|
||||
|
||||
let model_after = create_model()
|
||||
.load(&State::load("/tmp/test.json").unwrap())
|
||||
.unwrap();
|
||||
|
||||
let state_after = model_after.state();
|
||||
assert_eq!(state_before, state_after);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_ids_are_loaded() {
|
||||
let model_1 = create_model();
|
||||
|
@ -200,33 +320,80 @@ mod tests {
|
|||
assert_eq!(params_before_1, params_after_2);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn test_load_binary() {
|
||||
fn test_from_to_binary() {
|
||||
let model_1 = create_model();
|
||||
let mut model_2 = create_model();
|
||||
let model_2 = create_model();
|
||||
let params_before_1 = list_param_ids(&model_1);
|
||||
let params_before_2 = list_param_ids(&model_2);
|
||||
|
||||
// Write to binary.
|
||||
|
||||
let state = model_1.state();
|
||||
let mut binary = Vec::new();
|
||||
let writer = GzEncoder::new(&mut binary, Compression::default());
|
||||
serde_json::to_writer(writer, &state).unwrap();
|
||||
|
||||
// Load.
|
||||
|
||||
model_2 = model_2.load(&State::load_binary(&binary).unwrap()).unwrap();
|
||||
let params_after_2 = list_param_ids(&model_2);
|
||||
// To & From Bytes
|
||||
let bytes = model_1.state().to_bin().unwrap();
|
||||
let model_2 = model_2.load(&State::from_bin(&bytes).unwrap()).unwrap();
|
||||
|
||||
// Verify.
|
||||
|
||||
let params_after_2 = list_param_ids(&model_2);
|
||||
assert_ne!(params_before_1, params_before_2);
|
||||
assert_eq!(params_before_1, params_after_2);
|
||||
}
|
||||
|
||||
fn create_model() -> nn::Linear<TestBackend> {
|
||||
pub fn create_model() -> nn::Linear<TestBackend> {
|
||||
nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig::new(32, 32).with_bias(true))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests_save_load {
|
||||
use super::tests::create_model;
|
||||
use super::*;
|
||||
use crate::module::Module;
|
||||
|
||||
static FILE_PATH: &str = "/tmp/test_state";
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_jsongz_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::JsonGz)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_bin_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::Bin)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_bingz_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::BinGz)
|
||||
}
|
||||
|
||||
#[cfg(feature = "msgpack")]
|
||||
#[test]
|
||||
fn test_can_save_and_load_from_file_mpkgz_format() {
|
||||
test_can_save_and_load_from_file(StateFormat::MpkGz)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_bin_on_disk() {
|
||||
let model = create_model();
|
||||
model
|
||||
.state()
|
||||
.save("/tmp/model_compare", &StateFormat::Bin)
|
||||
.unwrap();
|
||||
let bytes = std::fs::read("/tmp/model_compare.bin").unwrap();
|
||||
let state = State::from_bin(&bytes).unwrap();
|
||||
|
||||
assert_eq!(state, model.state());
|
||||
}
|
||||
|
||||
fn test_can_save_and_load_from_file(format: StateFormat) {
|
||||
let model_before = create_model();
|
||||
let state_before = model_before.state();
|
||||
state_before.clone().save(FILE_PATH, &format).unwrap();
|
||||
|
||||
let model_after = create_model()
|
||||
.load(&State::load(FILE_PATH, &format).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let state_after = model_after.state();
|
||||
assert_eq!(state_before, state_after);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,9 @@ use alloc::vec::Vec;
|
|||
use core::cmp::Ordering;
|
||||
use core::ops::Range;
|
||||
|
||||
use crate::element::FloatNdArrayElement;
|
||||
// Current crate
|
||||
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
|
||||
use crate::element::FloatNdArrayElement;
|
||||
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
||||
use crate::{NdArrayDevice, SEED};
|
||||
|
||||
|
@ -15,10 +16,9 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
|
|||
|
||||
// External crates
|
||||
use libm::{cos, erf, sin, tanh};
|
||||
#[cfg(not(feature = "std"))]
|
||||
use num_traits::Float; // Can't compare two floats with no_std.
|
||||
|
||||
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
|
||||
#[cfg(not(feature = "std"))]
|
||||
use num_traits::Float;
|
||||
|
||||
impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
|
||||
|
|
|
@ -17,7 +17,7 @@ doc = ["tch/doc-only"]
|
|||
burn-tensor = {path = "../burn-tensor", version = "0.6.0"}
|
||||
libc = "0.2.0"
|
||||
|
||||
half = {workspace = true}
|
||||
half = {workspace = true, features = ["std"]}
|
||||
rand = {workspace = true, features = ["std"]}
|
||||
|
||||
[target.'cfg(not(target_arch = "aarch64"))'.dependencies]
|
||||
|
|
|
@ -19,6 +19,8 @@ experimental-named-tensor = []
|
|||
export_tests = ["burn-tensor-testgen"]
|
||||
std = [
|
||||
"rand/std",
|
||||
"half/std",
|
||||
"half/serde", # TODO: set default when https://github.com/starkat99/half-rs/issues/84 is fixed
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
|
|
|
@ -7,55 +7,55 @@ use libm::{pow, round};
|
|||
use rand::{distributions::Standard, Rng, RngCore};
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone)]
|
||||
pub struct DataSerialize<P> {
|
||||
pub value: Vec<P>,
|
||||
pub struct DataSerialize<E> {
|
||||
pub value: Vec<E>,
|
||||
pub shape: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(new, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Data<P, const D: usize> {
|
||||
pub value: Vec<P>,
|
||||
pub struct Data<E, const D: usize> {
|
||||
pub value: Vec<E>,
|
||||
pub shape: Shape<D>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Distribution<P> {
|
||||
pub enum Distribution<E> {
|
||||
Standard,
|
||||
Bernoulli(f64),
|
||||
Uniform(P, P),
|
||||
Uniform(E, E),
|
||||
Normal(f64, f64),
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct DistributionSampler<'a, P, R>
|
||||
pub struct DistributionSampler<'a, E, R>
|
||||
where
|
||||
Standard: rand::distributions::Distribution<P>,
|
||||
P: rand::distributions::uniform::SampleUniform,
|
||||
Standard: rand::distributions::Distribution<E>,
|
||||
E: rand::distributions::uniform::SampleUniform,
|
||||
R: RngCore,
|
||||
{
|
||||
kind: DistributionSamplerKind<P>,
|
||||
kind: DistributionSamplerKind<E>,
|
||||
rng: &'a mut R,
|
||||
}
|
||||
|
||||
pub enum DistributionSamplerKind<P>
|
||||
pub enum DistributionSamplerKind<E>
|
||||
where
|
||||
Standard: rand::distributions::Distribution<P>,
|
||||
P: rand::distributions::uniform::SampleUniform,
|
||||
Standard: rand::distributions::Distribution<E>,
|
||||
E: rand::distributions::uniform::SampleUniform,
|
||||
{
|
||||
Standard(rand::distributions::Standard),
|
||||
Uniform(rand::distributions::Uniform<P>),
|
||||
Uniform(rand::distributions::Uniform<E>),
|
||||
Bernoulli(rand::distributions::Bernoulli),
|
||||
Normal(rand_distr::Normal<f64>),
|
||||
}
|
||||
|
||||
impl<'a, P, R> DistributionSampler<'a, P, R>
|
||||
impl<'a, E, R> DistributionSampler<'a, E, R>
|
||||
where
|
||||
Standard: rand::distributions::Distribution<P>,
|
||||
P: rand::distributions::uniform::SampleUniform,
|
||||
P: Element,
|
||||
Standard: rand::distributions::Distribution<E>,
|
||||
E: rand::distributions::uniform::SampleUniform,
|
||||
E: Element,
|
||||
R: RngCore,
|
||||
{
|
||||
pub fn sample(&mut self) -> P {
|
||||
pub fn sample(&mut self) -> E {
|
||||
match &self.kind {
|
||||
DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
|
||||
DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
|
||||
|
@ -71,12 +71,12 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<P> Distribution<P>
|
||||
impl<E> Distribution<E>
|
||||
where
|
||||
Standard: rand::distributions::Distribution<P>,
|
||||
P: rand::distributions::uniform::SampleUniform,
|
||||
Standard: rand::distributions::Distribution<E>,
|
||||
E: rand::distributions::uniform::SampleUniform,
|
||||
{
|
||||
pub fn sampler<R: RngCore>(self, rng: &'_ mut R) -> DistributionSampler<'_, P, R> {
|
||||
pub fn sampler<R: RngCore>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> {
|
||||
let kind = match self {
|
||||
Distribution::Standard => {
|
||||
DistributionSamplerKind::Standard(rand::distributions::Standard {})
|
||||
|
@ -96,23 +96,25 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<P> Distribution<P>
|
||||
impl<E> Distribution<E>
|
||||
where
|
||||
P: Element,
|
||||
E: Element,
|
||||
{
|
||||
pub fn convert<E: Element>(self) -> Distribution<E> {
|
||||
pub fn convert<EOther: Element>(self) -> Distribution<EOther> {
|
||||
match self {
|
||||
Distribution::Standard => Distribution::Standard,
|
||||
Distribution::Uniform(a, b) => Distribution::Uniform(E::from_elem(a), E::from_elem(b)),
|
||||
Distribution::Uniform(a, b) => {
|
||||
Distribution::Uniform(EOther::from_elem(a), EOther::from_elem(b))
|
||||
}
|
||||
Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob),
|
||||
Distribution::Normal(mean, std) => Distribution::Normal(mean, std),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, P: Element> Data<P, D> {
|
||||
pub fn convert<E: Element>(self) -> Data<E, D> {
|
||||
let value: Vec<E> = self.value.into_iter().map(|a| a.elem()).collect();
|
||||
impl<const D: usize, E: Element> Data<E, D> {
|
||||
pub fn convert<EOther: Element>(self) -> Data<EOther, D> {
|
||||
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
|
||||
|
||||
Data {
|
||||
value,
|
||||
|
@ -121,9 +123,9 @@ impl<const D: usize, P: Element> Data<P, D> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: Element> DataSerialize<P> {
|
||||
pub fn convert<E: Element>(self) -> DataSerialize<E> {
|
||||
let value: Vec<E> = self.value.into_iter().map(|a| a.elem()).collect();
|
||||
impl<E: Element> DataSerialize<E> {
|
||||
pub fn convert<EOther: Element>(self) -> DataSerialize<EOther> {
|
||||
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
|
||||
|
||||
DataSerialize {
|
||||
value,
|
||||
|
@ -142,23 +144,23 @@ impl<const D: usize> Data<bool, D> {
|
|||
}
|
||||
}
|
||||
}
|
||||
impl<P: Element, const D: usize> Data<P, D> {
|
||||
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution<P>, rng: &mut R) -> Self {
|
||||
impl<E: Element, const D: usize> Data<E, D> {
|
||||
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution<E>, rng: &mut R) -> Self {
|
||||
let num_elements = shape.num_elements();
|
||||
let mut data = Vec::with_capacity(num_elements);
|
||||
|
||||
for _ in 0..num_elements {
|
||||
data.push(P::random(distribution, rng));
|
||||
data.push(E::random(distribution, rng));
|
||||
}
|
||||
|
||||
Data::new(data, shape)
|
||||
}
|
||||
}
|
||||
impl<P: core::fmt::Debug, const D: usize> Data<P, D>
|
||||
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
|
||||
where
|
||||
P: Element,
|
||||
E: Element,
|
||||
{
|
||||
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<P, D> {
|
||||
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<E, D> {
|
||||
let shape = shape.into();
|
||||
let num_elements = shape.num_elements();
|
||||
let mut data = Vec::with_capacity(num_elements);
|
||||
|
@ -169,16 +171,16 @@ where
|
|||
|
||||
Data::new(data, shape)
|
||||
}
|
||||
pub fn zeros_(shape: Shape<D>, _kind: P) -> Data<P, D> {
|
||||
pub fn zeros_(shape: Shape<D>, _kind: E) -> Data<E, D> {
|
||||
Self::zeros(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug, const D: usize> Data<P, D>
|
||||
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
|
||||
where
|
||||
P: Element,
|
||||
E: Element,
|
||||
{
|
||||
pub fn ones(shape: Shape<D>) -> Data<P, D> {
|
||||
pub fn ones(shape: Shape<D>) -> Data<E, D> {
|
||||
let num_elements = shape.num_elements();
|
||||
let mut data = Vec::with_capacity(num_elements);
|
||||
|
||||
|
@ -188,13 +190,13 @@ where
|
|||
|
||||
Data::new(data, shape)
|
||||
}
|
||||
pub fn ones_(shape: Shape<D>, _kind: P) -> Data<P, D> {
|
||||
pub fn ones_(shape: Shape<D>, _kind: E) -> Data<E, D> {
|
||||
Self::ones(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug + Copy, const D: usize> Data<P, D> {
|
||||
pub fn serialize(&self) -> DataSerialize<P> {
|
||||
impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
|
||||
pub fn serialize(&self) -> DataSerialize<E> {
|
||||
DataSerialize {
|
||||
value: self.value.clone(),
|
||||
shape: self.shape.dims.to_vec(),
|
||||
|
@ -202,7 +204,7 @@ impl<P: core::fmt::Debug + Copy, const D: usize> Data<P, D> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<P, D> {
|
||||
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E, D> {
|
||||
pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
|
||||
assert_eq!(self.shape, other.shape);
|
||||
|
||||
|
@ -246,24 +248,24 @@ impl<const D: usize> Data<usize, D> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: Clone, const D: usize> From<&DataSerialize<P>> for Data<P, D> {
|
||||
fn from(data: &DataSerialize<P>) -> Self {
|
||||
impl<E: Clone, const D: usize> From<&DataSerialize<E>> for Data<E, D> {
|
||||
fn from(data: &DataSerialize<E>) -> Self {
|
||||
let mut dims = [0; D];
|
||||
dims[..D].copy_from_slice(&data.shape[..D]);
|
||||
Data::new(data.value.clone(), Shape::new(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> From<DataSerialize<P>> for Data<P, D> {
|
||||
fn from(data: DataSerialize<P>) -> Self {
|
||||
impl<E, const D: usize> From<DataSerialize<E>> for Data<E, D> {
|
||||
fn from(data: DataSerialize<E>) -> Self {
|
||||
let mut dims = [0; D];
|
||||
dims[..D].copy_from_slice(&data.shape[..D]);
|
||||
Data::new(data.value, Shape::new(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug + Copy, const A: usize> From<[P; A]> for Data<P, 1> {
|
||||
fn from(elems: [P; A]) -> Self {
|
||||
impl<E: core::fmt::Debug + Copy, const A: usize> From<[E; A]> for Data<E, 1> {
|
||||
fn from(elems: [E; A]) -> Self {
|
||||
let mut data = Vec::with_capacity(2 * A);
|
||||
for elem in elems.into_iter() {
|
||||
data.push(elem);
|
||||
|
@ -273,8 +275,8 @@ impl<P: core::fmt::Debug + Copy, const A: usize> From<[P; A]> for Data<P, 1> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug + Copy> From<&[P]> for Data<P, 1> {
|
||||
fn from(elems: &[P]) -> Self {
|
||||
impl<E: core::fmt::Debug + Copy> From<&[E]> for Data<E, 1> {
|
||||
fn from(elems: &[E]) -> Self {
|
||||
let mut data = Vec::with_capacity(elems.len());
|
||||
for elem in elems.iter() {
|
||||
data.push(*elem);
|
||||
|
@ -284,8 +286,8 @@ impl<P: core::fmt::Debug + Copy> From<&[P]> for Data<P, 1> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[P; B]; A]> for Data<P, 2> {
|
||||
fn from(elems: [[P; B]; A]) -> Self {
|
||||
impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[E; B]; A]> for Data<E, 2> {
|
||||
fn from(elems: [[E; B]; A]) -> Self {
|
||||
let mut data = Vec::with_capacity(A * B);
|
||||
for elem in elems.into_iter().take(A) {
|
||||
for elem in elem.into_iter().take(B) {
|
||||
|
@ -297,10 +299,10 @@ impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[P; B]; A
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
|
||||
From<[[[P; C]; B]; A]> for Data<P, 3>
|
||||
impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
|
||||
From<[[[E; C]; B]; A]> for Data<E, 3>
|
||||
{
|
||||
fn from(elems: [[[P; C]; B]; A]) -> Self {
|
||||
fn from(elems: [[[E; C]; B]; A]) -> Self {
|
||||
let mut data = Vec::with_capacity(A * B * C);
|
||||
|
||||
for elem in elems.into_iter().take(A) {
|
||||
|
@ -316,14 +318,14 @@ impl<P: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
|
|||
}
|
||||
|
||||
impl<
|
||||
P: core::fmt::Debug + Copy,
|
||||
E: core::fmt::Debug + Copy,
|
||||
const A: usize,
|
||||
const B: usize,
|
||||
const C: usize,
|
||||
const D: usize,
|
||||
> From<[[[[P; D]; C]; B]; A]> for Data<P, 4>
|
||||
> From<[[[[E; D]; C]; B]; A]> for Data<E, 4>
|
||||
{
|
||||
fn from(elems: [[[[P; D]; C]; B]; A]) -> Self {
|
||||
fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
|
||||
let mut data = Vec::with_capacity(A * B * C * D);
|
||||
|
||||
for elem in elems.into_iter().take(A) {
|
||||
|
@ -340,11 +342,12 @@ impl<
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: core::fmt::Debug, const D: usize> core::fmt::Display for Data<P, D> {
|
||||
impl<E: core::fmt::Debug, const D: usize> core::fmt::Display for Data<E, D> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(format!("{:?}", &self.value).as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -1,27 +1,29 @@
|
|||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::module::State;
|
||||
use burn_core::module::{State, StateFormat};
|
||||
use burn_core::tensor::Element;
|
||||
|
||||
pub struct FileCheckpointer<P> {
|
||||
directory: String,
|
||||
name: String,
|
||||
num_keep: usize,
|
||||
format: StateFormat,
|
||||
_precision: P,
|
||||
}
|
||||
|
||||
impl<P: Element> FileCheckpointer<P> {
|
||||
pub fn new(directory: &str, name: &str, num_keep: usize) -> Self {
|
||||
pub fn new(directory: &str, name: &str, num_keep: usize, format: StateFormat) -> Self {
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
|
||||
Self {
|
||||
directory: directory.to_string(),
|
||||
name: name.to_string(),
|
||||
num_keep,
|
||||
format,
|
||||
_precision: P::default(),
|
||||
}
|
||||
}
|
||||
fn path_for_epoch(&self, epoch: usize) -> String {
|
||||
format!("{}/{}-{}.json.gz", self.directory, self.name, epoch)
|
||||
format!("{}/{}-{}", self.directory, self.name, epoch)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,7 +38,7 @@ where
|
|||
|
||||
state
|
||||
.convert::<P>()
|
||||
.save(&file_path)
|
||||
.save(&file_path, &self.format)
|
||||
.map_err(CheckpointerError::IOError)?;
|
||||
|
||||
if self.num_keep > epoch {
|
||||
|
@ -57,7 +59,8 @@ where
|
|||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
|
||||
|
||||
let state = State::<P>::load(&file_path).map_err(CheckpointerError::StateError)?;
|
||||
let state =
|
||||
State::<P>::load(&file_path, &self.format).map_err(CheckpointerError::StateError)?;
|
||||
|
||||
Ok(state.convert())
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::metric::dashboard::cli::CLIDashboardRenderer;
|
|||
use crate::metric::dashboard::Dashboard;
|
||||
use crate::metric::{Adaptor, Metric, Numeric};
|
||||
use crate::AsyncTrainerCallback;
|
||||
use burn_core::module::ADModule;
|
||||
use burn_core::module::{ADModule, StateFormat};
|
||||
use burn_core::optim::Optimizer;
|
||||
use burn_core::tensor::backend::ADBackend;
|
||||
use burn_core::tensor::Element;
|
||||
|
@ -143,16 +143,19 @@ where
|
|||
pub fn with_file_checkpointer<P: Element + serde::de::DeserializeOwned + serde::Serialize>(
|
||||
mut self,
|
||||
num_keep: usize,
|
||||
format: StateFormat,
|
||||
) -> Self {
|
||||
self.checkpointer_model = Some(Arc::new(FileCheckpointer::<P>::new(
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"model",
|
||||
num_keep,
|
||||
format.clone(),
|
||||
)));
|
||||
self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::<P>::new(
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"optim",
|
||||
num_keep,
|
||||
format,
|
||||
)));
|
||||
self
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ std = [
|
|||
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
||||
train = ["std", "burn-train"] # Training requires std
|
||||
|
||||
msgpack = ["burn-core/msgpack"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
|
|
|
@ -15,10 +15,6 @@ default = []
|
|||
[dependencies]
|
||||
burn = {path = "../../burn", default-features = false}
|
||||
burn-ndarray = {path = "../../burn-ndarray", default-features = false}
|
||||
|
||||
# 2.0 supports no_std and serder
|
||||
bincode = {version = "2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", default-features = false, features = ["alloc", "serde"]}
|
||||
|
||||
serde = {workspace = true}
|
||||
wasm-bindgen = "0.2.84"
|
||||
|
||||
|
@ -26,6 +22,5 @@ wasm-bindgen = "0.2.84"
|
|||
burn-dataset = {path = "../../burn-dataset"}
|
||||
|
||||
[build-dependencies]
|
||||
bincode = {version = "2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", default-features = false, features = ["alloc", "serde"]}
|
||||
burn = {path = "../../burn"}
|
||||
serde = {workspace = true}
|
||||
|
|
|
@ -27,14 +27,9 @@ makes it possible to build and run the model with the `wasm32-unknown-unknown` t
|
|||
special system library, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) on how to
|
||||
include burn dependencies without `std`).
|
||||
|
||||
For this demo, we use trained parameters (`model-4.json.gz`) and model (`model.rs`) from the
|
||||
For this demo, we use trained parameters (`model.bin`) and model (`model.rs`) from the
|
||||
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).
|
||||
|
||||
During the build time `model-4.json.gz` is converted to
|
||||
[`bincode`](https://github.com/bincode-org/bincode) (for compactness) and included as part of the
|
||||
final wasm output. The MNIST model is initialized with trained weights from memory during the
|
||||
runtime.
|
||||
|
||||
The inference API for JavaScript is exposed with the help of
|
||||
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.
|
||||
|
||||
|
@ -75,8 +70,6 @@ byte file is the model's parameters. The rest of 356,744 bytes contain all the c
|
|||
|
||||
There are several planned enhancements in place:
|
||||
|
||||
- [#201](https://github.com/burn-rs/burn/issues/201) - Saving model's params in binary format. This
|
||||
will simplify the inference code.
|
||||
- [#202](https://github.com/burn-rs/burn/issues/202) - Saving model's params in half-precision and
|
||||
loading back in full. This can be half the size of the wasm file.
|
||||
- [#243](https://github.com/burn-rs/burn/issues/243) - New WebGPU backend would allow computation
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
use std::fs;
|
||||
|
||||
use burn::module::State;
|
||||
|
||||
use bincode::config;
|
||||
|
||||
const GENERATED_FILE_NAME: &str = "mnist_model_state.bincode";
|
||||
const MODEL_STATE_FILE_NAME: &str = "model-4.json.gz";
|
||||
|
||||
/// This build step is responsible for converting JSON serialized to Bincode serilization
|
||||
/// in order to make the file small and efficient for bundling the binary into wasm code.
|
||||
///
|
||||
/// This will be removed once https://github.com/burn-rs/burn/issues/201 is resolved.
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed={MODEL_STATE_FILE_NAME}");
|
||||
let config = config::standard();
|
||||
let path: std::path::PathBuf = [
|
||||
std::env::var("OUT_DIR").expect("No build target path set"),
|
||||
GENERATED_FILE_NAME.into(),
|
||||
]
|
||||
.iter()
|
||||
.collect();
|
||||
let state: State<f32> =
|
||||
State::load(MODEL_STATE_FILE_NAME).expect(concat!("Model JSON file could not be loaded"));
|
||||
let serialized =
|
||||
bincode::serde::encode_to_vec(state, config).expect("Encoding state into bincode failed");
|
||||
fs::write(path, serialized).expect("Write failed");
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -4,27 +4,14 @@ use burn::module::Module;
|
|||
use burn::module::State;
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
|
||||
use bincode::{
|
||||
config::{self, Configuration},
|
||||
serde::decode_from_slice,
|
||||
};
|
||||
|
||||
pub type Backend = NdArrayBackend<f32>;
|
||||
|
||||
const BINCODE_CONF: Configuration = config::standard();
|
||||
|
||||
// Bundled bincode serialized model state object
|
||||
// see https://github.com/bincode-org/bincode and https://doc.rust-lang.org/std/macro.include_bytes.html
|
||||
static STATE_ENCODED: &[u8] =
|
||||
include_bytes!(concat!(env!("OUT_DIR"), "/mnist_model_state.bincode"));
|
||||
static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
|
||||
|
||||
/// Builds and loads trained parameters into the model.
|
||||
pub fn build_and_load_model() -> Model<Backend> {
|
||||
let model: Model<Backend> = Model::new();
|
||||
|
||||
// TODO: fix forward slash to make the paths work in windows
|
||||
let (state, _len): (State<f32>, usize) =
|
||||
decode_from_slice(STATE_ENCODED, BINCODE_CONF).expect("Failed to decode state");
|
||||
let state: State<f32> = State::from_bin(STATE_ENCODED).expect("Failed to decode state");
|
||||
|
||||
model
|
||||
.load(&state)
|
||||
|
|
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||
use crate::data::MNISTBatcher;
|
||||
use crate::model::Model;
|
||||
|
||||
use burn::module::{Module, StateFormat};
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::{Adam, AdamConfig};
|
||||
use burn::{
|
||||
|
@ -64,14 +65,21 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
.metric_valid_plot(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<f32>(2)
|
||||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(model, optim);
|
||||
|
||||
let _model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config
|
||||
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
|
||||
.unwrap();
|
||||
|
||||
// We save a bin version of the model to be loaded with no_std environement.
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<f32>()
|
||||
.save(&format!("{ARTIFACT_DIR}/model"), &StateFormat::Bin)
|
||||
.expect("Failed to save trained model");
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
module::Module,
|
||||
module::{Module, StateFormat},
|
||||
nn::transformer::TransformerEncoderConfig,
|
||||
optim::{Sgd, SgdConfig},
|
||||
tensor::backend::ADBackend,
|
||||
|
@ -78,7 +78,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
|||
.metric_valid(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<f32>(2)
|
||||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(model, optim);
|
||||
|
@ -86,9 +86,10 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
|||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config.save(&format!("{artifact_dir}/config.json")).unwrap();
|
||||
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<f32>()
|
||||
.save(&format!("{artifact_dir}/model.json.gz"))
|
||||
.convert::<burn::tensor::f16>()
|
||||
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ use crate::{
|
|||
data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer},
|
||||
model::{TextGenerationModel, TextGenerationModelConfig},
|
||||
};
|
||||
use burn::data::dataset::transform::SamplerDataset;
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
|
||||
|
@ -15,6 +14,7 @@ use burn::{
|
|||
LearnerBuilder,
|
||||
},
|
||||
};
|
||||
use burn::{data::dataset::transform::SamplerDataset, module::StateFormat};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Config)]
|
||||
|
@ -75,7 +75,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
|||
.metric_valid(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<f32>(2)
|
||||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
|
||||
.devices(vec![device])
|
||||
.grads_accumulation(16)
|
||||
.num_epochs(config.num_epochs)
|
||||
|
@ -84,9 +84,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
|||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config.save(&format!("{artifact_dir}/config.json")).unwrap();
|
||||
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<f32>()
|
||||
.save(&format!("{artifact_dir}/model.json.gz"))
|
||||
.convert::<burn::tensor::f16>()
|
||||
.save(&format!("{artifact_dir}/model"), &StateFormat::default())
|
||||
.unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue