mirror of https://github.com/tracel-ai/burn.git
Add feature to load config / state from binary slice. (#164)
This commit is contained in:
parent
2401d8ad96
commit
f907d0025d
|
@ -34,6 +34,13 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
|
||||||
.map_err(|_| ConfigError::FileNotFound(file.to_string()))?;
|
.map_err(|_| ConfigError::FileNotFound(file.to_string()))?;
|
||||||
config_from_str(&content)
|
config_from_str(&content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
|
||||||
|
let content = std::str::from_utf8(data).map_err(|_| {
|
||||||
|
ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
|
||||||
|
})?;
|
||||||
|
config_from_str(content)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn config_to_json<C: Config>(config: &C) -> String {
|
pub fn config_to_json<C: Config>(config: &C) -> String {
|
||||||
|
|
|
@ -128,6 +128,13 @@ where
|
||||||
|
|
||||||
Ok(state)
|
Ok(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_binary(data: &[u8]) -> Result<Self, StateError> {
|
||||||
|
let reader = GzDecoder::new(data);
|
||||||
|
let state = serde_json::from_reader(reader).unwrap();
|
||||||
|
|
||||||
|
Ok(state)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -179,6 +186,31 @@ mod tests {
|
||||||
assert_eq!(params_before_1, params_after_2);
|
assert_eq!(params_before_1, params_after_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_load_binary() {
|
||||||
|
let model_1 = create_model();
|
||||||
|
let mut model_2 = create_model();
|
||||||
|
let params_before_1 = list_param_ids(&model_1);
|
||||||
|
let params_before_2 = list_param_ids(&model_2);
|
||||||
|
|
||||||
|
// Write to binary.
|
||||||
|
|
||||||
|
let state = model_1.state();
|
||||||
|
let mut binary = Vec::new();
|
||||||
|
let writer = GzEncoder::new(&mut binary, Compression::default());
|
||||||
|
serde_json::to_writer(writer, &state).unwrap();
|
||||||
|
|
||||||
|
// Load.
|
||||||
|
|
||||||
|
model_2.load(&State::load_binary(&binary).unwrap()).unwrap();
|
||||||
|
let params_after_2 = list_param_ids(&model_2);
|
||||||
|
|
||||||
|
// Verify.
|
||||||
|
|
||||||
|
assert_ne!(params_before_1, params_before_2);
|
||||||
|
assert_eq!(params_before_1, params_after_2);
|
||||||
|
}
|
||||||
|
|
||||||
fn create_model() -> nn::Linear<TestBackend> {
|
fn create_model() -> nn::Linear<TestBackend> {
|
||||||
nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
|
nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
|
||||||
d_input: 32,
|
d_input: 32,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use burn::config::Config;
|
use burn::config::{config_to_json, Config};
|
||||||
use burn_core as burn;
|
use burn_core as burn;
|
||||||
|
|
||||||
#[derive(Config, Debug, PartialEq, Eq)]
|
#[derive(Config, Debug, PartialEq, Eq)]
|
||||||
|
@ -90,3 +90,13 @@ fn enum_config_should_impl_display() {
|
||||||
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
|
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
|
||||||
assert_eq!(burn::config::config_to_json(&config), config.to_string());
|
assert_eq!(burn::config::config_to_json(&config), config.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn struct_config_can_load_binary() {
|
||||||
|
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
|
||||||
|
|
||||||
|
let binary = config_to_json(&config).as_bytes().to_vec();
|
||||||
|
|
||||||
|
let config_loaded = TestStructConfig::load_binary(&binary).unwrap();
|
||||||
|
assert_eq!(config, config_loaded);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue