mirror of https://github.com/tracel-ai/burn.git
Fix warnings when using `record-backward-compat` (#1977)
This commit is contained in:
parent
8af2b719a1
commit
6f158af4b1
|
@ -1,15 +1,17 @@
|
||||||
use core::marker::PhantomData;
|
use core::marker::PhantomData;
|
||||||
|
|
||||||
use super::{PrecisionSettings, Record};
|
use super::{PrecisionSettings, Record};
|
||||||
use alloc::format;
|
|
||||||
use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData};
|
use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[cfg(not(feature = "record-backward-compat"))]
|
||||||
|
use alloc::format;
|
||||||
#[cfg(feature = "record-backward-compat")]
|
#[cfg(feature = "record-backward-compat")]
|
||||||
use burn_tensor::DataSerialize;
|
use burn_tensor::DataSerialize;
|
||||||
|
|
||||||
/// Versioned serde data deserialization to maintain backward compatibility between formats.
|
/// Versioned serde data deserialization to maintain backward compatibility between formats.
|
||||||
#[cfg(feature = "record-backward-compat")]
|
#[cfg(feature = "record-backward-compat")]
|
||||||
|
#[allow(deprecated)]
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
enum TensorDataSerde<E> {
|
enum TensorDataSerde<E> {
|
||||||
|
@ -25,7 +27,7 @@ where
|
||||||
{
|
{
|
||||||
#[cfg(feature = "record-backward-compat")]
|
#[cfg(feature = "record-backward-compat")]
|
||||||
{
|
{
|
||||||
let data = match TensorDataSerde::<D, E>::deserialize(deserializer)? {
|
let data = match TensorDataSerde::<E>::deserialize(deserializer)? {
|
||||||
TensorDataSerde::V1(data) => data.into_tensor_data(),
|
TensorDataSerde::V1(data) => data.into_tensor_data(),
|
||||||
// NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16
|
// NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16
|
||||||
TensorDataSerde::V2(data) => data.convert::<E>(),
|
TensorDataSerde::V2(data) => data.convert::<E>(),
|
||||||
|
|
Loading…
Reference in New Issue