Add feature to load config / state from binary slice. (#164)

This commit is contained in:
Aaron Roney 2023-02-16 05:43:43 -08:00 committed by GitHub
parent 2401d8ad96
commit f907d0025d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 1 deletions

View File

@ -34,6 +34,13 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
.map_err(|_| ConfigError::FileNotFound(file.to_string()))?;
config_from_str(&content)
}
fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
let content = std::str::from_utf8(data).map_err(|_| {
ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
})?;
config_from_str(content)
}
}
pub fn config_to_json<C: Config>(config: &C) -> String {

View File

@ -128,6 +128,13 @@ where
Ok(state)
}
pub fn load_binary(data: &[u8]) -> Result<Self, StateError> {
let reader = GzDecoder::new(data);
let state = serde_json::from_reader(reader).unwrap();
Ok(state)
}
}
#[cfg(test)]
@ -179,6 +186,31 @@ mod tests {
assert_eq!(params_before_1, params_after_2);
}
#[test]
fn test_load_binary() {
let model_1 = create_model();
let mut model_2 = create_model();
let params_before_1 = list_param_ids(&model_1);
let params_before_2 = list_param_ids(&model_2);
// Write to binary.
let state = model_1.state();
let mut binary = Vec::new();
let writer = GzEncoder::new(&mut binary, Compression::default());
serde_json::to_writer(writer, &state).unwrap();
// Load.
model_2.load(&State::load_binary(&binary).unwrap()).unwrap();
let params_after_2 = list_param_ids(&model_2);
// Verify.
assert_ne!(params_before_1, params_before_2);
assert_eq!(params_before_1, params_after_2);
}
fn create_model() -> nn::Linear<TestBackend> {
nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
d_input: 32,

View File

@ -1,4 +1,4 @@
use burn::config::Config;
use burn::config::{config_to_json, Config};
use burn_core as burn;
#[derive(Config, Debug, PartialEq, Eq)]
@ -90,3 +90,13 @@ fn enum_config_should_impl_display() {
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
assert_eq!(burn::config::config_to_json(&config), config.to_string());
}
#[test]
fn struct_config_can_load_binary() {
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
let binary = config_to_json(&config).as_bytes().to_vec();
let config_loaded = TestStructConfig::load_binary(&binary).unwrap();
assert_eq!(config, config_loaded);
}