Fix warnings when using `record-backward-compat` (#1977)

This commit is contained in:
Guillaume Lagrange 2024-07-08 07:58:50 -04:00 committed by GitHub
parent 8af2b719a1
commit 6f158af4b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -1,15 +1,17 @@
use core::marker::PhantomData; use core::marker::PhantomData;
use super::{PrecisionSettings, Record}; use super::{PrecisionSettings, Record};
use alloc::format;
use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData}; use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(not(feature = "record-backward-compat"))]
use alloc::format;
#[cfg(feature = "record-backward-compat")] #[cfg(feature = "record-backward-compat")]
use burn_tensor::DataSerialize; use burn_tensor::DataSerialize;
/// Versioned serde data deserialization to maintain backward compatibility between formats. /// Versioned serde data deserialization to maintain backward compatibility between formats.
#[cfg(feature = "record-backward-compat")] #[cfg(feature = "record-backward-compat")]
#[allow(deprecated)]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
enum TensorDataSerde<E> { enum TensorDataSerde<E> {
@ -25,7 +27,7 @@ where
{ {
#[cfg(feature = "record-backward-compat")] #[cfg(feature = "record-backward-compat")]
{ {
let data = match TensorDataSerde::<D, E>::deserialize(deserializer)? { let data = match TensorDataSerde::<E>::deserialize(deserializer)? {
TensorDataSerde::V1(data) => data.into_tensor_data(), TensorDataSerde::V1(data) => data.into_tensor_data(),
// NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16 // NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16
TensorDataSerde::V2(data) => data.convert::<E>(), TensorDataSerde::V2(data) => data.convert::<E>(),