mirror of https://github.com/tracel-ai/burn.git
Add feature to load config / state from binary slice. (#164)
This commit is contained in:
parent
2401d8ad96
commit
f907d0025d
|
@ -34,6 +34,13 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
|
|||
.map_err(|_| ConfigError::FileNotFound(file.to_string()))?;
|
||||
config_from_str(&content)
|
||||
}
|
||||
|
||||
fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
|
||||
let content = std::str::from_utf8(data).map_err(|_| {
|
||||
ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
|
||||
})?;
|
||||
config_from_str(content)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config_to_json<C: Config>(config: &C) -> String {
|
||||
|
|
|
@ -128,6 +128,13 @@ where
|
|||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
pub fn load_binary(data: &[u8]) -> Result<Self, StateError> {
|
||||
let reader = GzDecoder::new(data);
|
||||
let state = serde_json::from_reader(reader).unwrap();
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -179,6 +186,31 @@ mod tests {
|
|||
assert_eq!(params_before_1, params_after_2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_binary() {
|
||||
let model_1 = create_model();
|
||||
let mut model_2 = create_model();
|
||||
let params_before_1 = list_param_ids(&model_1);
|
||||
let params_before_2 = list_param_ids(&model_2);
|
||||
|
||||
// Write to binary.
|
||||
|
||||
let state = model_1.state();
|
||||
let mut binary = Vec::new();
|
||||
let writer = GzEncoder::new(&mut binary, Compression::default());
|
||||
serde_json::to_writer(writer, &state).unwrap();
|
||||
|
||||
// Load.
|
||||
|
||||
model_2.load(&State::load_binary(&binary).unwrap()).unwrap();
|
||||
let params_after_2 = list_param_ids(&model_2);
|
||||
|
||||
// Verify.
|
||||
|
||||
assert_ne!(params_before_1, params_before_2);
|
||||
assert_eq!(params_before_1, params_after_2);
|
||||
}
|
||||
|
||||
fn create_model() -> nn::Linear<TestBackend> {
|
||||
nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
|
||||
d_input: 32,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use burn::config::Config;
|
||||
use burn::config::{config_to_json, Config};
|
||||
use burn_core as burn;
|
||||
|
||||
#[derive(Config, Debug, PartialEq, Eq)]
|
||||
|
@ -90,3 +90,13 @@ fn enum_config_should_impl_display() {
|
|||
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
|
||||
assert_eq!(burn::config::config_to_json(&config), config.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn struct_config_can_load_binary() {
|
||||
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
|
||||
|
||||
let binary = config_to_json(&config).as_bytes().to_vec();
|
||||
|
||||
let config_loaded = TestStructConfig::load_binary(&binary).unwrap();
|
||||
assert_eq!(config, config_loaded);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue