mirror of https://github.com/tracel-ai/burn.git
Fix: serde dependency (#1091)
* Re-export serde * Fix * USe another strategy * Fix * Fix * Update de book
This commit is contained in:
parent
b0a2b30ed1
commit
fceb036c6f
|
@ -1,8 +1,7 @@
|
|||
# Model
|
||||
|
||||
The first step is to create a project and add the different Burn dependencies. In the `Cargo.toml`
|
||||
file, add the `burn` dependency with `train` and `wgpu` features. Note that the `serde` dependency
|
||||
is also mandatory for the time being, as it is needed for serialization.
|
||||
file, add the `burn` dependency with `train` and `wgpu` features.
|
||||
|
||||
```toml
|
||||
[package]
|
||||
|
@ -12,9 +11,6 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
burn = { version = "0.12.0", features=["train", "wgpu"]}
|
||||
|
||||
# Serialization
|
||||
serde = "1"
|
||||
```
|
||||
|
||||
Our goal will be to create a basic convolutional neural network used for image classification. We
|
||||
|
|
|
@ -6,6 +6,9 @@
|
|||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// Re-export serde for proc macros.
|
||||
pub use serde;
|
||||
|
||||
/// The configuration module.
|
||||
pub mod config;
|
||||
|
||||
|
|
|
@ -22,7 +22,8 @@ impl ConfigEnumAnalyzer {
|
|||
let data = &self.data.variants;
|
||||
|
||||
quote! {
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
enum #enum_name {
|
||||
#data
|
||||
}
|
||||
|
@ -80,10 +81,10 @@ impl ConfigEnumAnalyzer {
|
|||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl serde::Serialize for #name {
|
||||
impl burn::serde::Serialize for #name {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
S: burn::serde::Serializer {
|
||||
let serde_state = match self {
|
||||
#(#variants),*
|
||||
};
|
||||
|
@ -105,10 +106,10 @@ impl ConfigEnumAnalyzer {
|
|||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl<'de> serde::Deserialize<'de> for #name {
|
||||
impl<'de> burn::serde::Deserialize<'de> for #name {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de> {
|
||||
D: burn::serde::Deserializer<'de> {
|
||||
let serde_state = #enum_name::deserialize(deserializer)?;
|
||||
Ok(match serde_state {
|
||||
#(#variants),*
|
||||
|
|
|
@ -85,12 +85,13 @@ impl ConfigStructAnalyzer {
|
|||
});
|
||||
|
||||
quote! {
|
||||
impl serde::Serialize for #name {
|
||||
impl burn::serde::Serialize for #name {
|
||||
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
#[derive(serde::Serialize)]
|
||||
S: burn::serde::Serializer {
|
||||
#[derive(burn::serde::Serialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#struct_gen
|
||||
|
||||
let serde_state = #struct_name {
|
||||
|
@ -116,11 +117,12 @@ impl ConfigStructAnalyzer {
|
|||
});
|
||||
|
||||
quote! {
|
||||
impl<'de> serde::Deserialize<'de> for #name {
|
||||
impl<'de> burn::serde::Deserialize<'de> for #name {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de> {
|
||||
#[derive(serde::Deserialize)]
|
||||
D: burn::serde::Deserializer<'de> {
|
||||
#[derive(burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#struct_gen
|
||||
|
||||
let serde_state = #struct_name::deserialize(deserializer)?;
|
||||
|
|
|
@ -24,7 +24,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
|
|||
pub #name: <#ty as burn::record::Record>::Item<S>,
|
||||
});
|
||||
bounds.extend(quote! {
|
||||
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
|
||||
<#ty as burn::record::Record>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
|
||||
});
|
||||
}
|
||||
let bound = bounds.to_string();
|
||||
|
@ -32,7 +32,8 @@ impl RecordItemCodegen for StructRecordItemCodegen {
|
|||
quote! {
|
||||
|
||||
/// The record item type for the module.
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, burn::serde::Serialize, burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#[serde(bound = #bound)]
|
||||
pub struct #item_name #generics {
|
||||
#fields
|
||||
|
|
Loading…
Reference in New Issue