Make all struct CamelCase (#1316)

Dilshod Tadjibaev 2024-02-15 13:00:37 -06:00 committed by GitHub
parent dfb739c89a
commit 44266d5fd4
39 changed files with 216 additions and 216 deletions

View File

@@ -194,7 +194,7 @@ impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
```
Subsequently, we'll go into implementing our custom backend trait for the WGPU backend.
Note that we won't go into supporting the `fusion` feature flag in this tutorial, so
we implement the trait for the raw `WgpuBackend` type.
```rust, ignore

View File

@@ -16,15 +16,15 @@ at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/exa
```rust , ignore
use burn::{
-data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};
-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
device: B::Device,
}
-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
@@ -42,13 +42,13 @@ Next, we need to actually implement the batching logic.
```rust , ignore
#[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 3>,
pub targets: Tensor<B, 1, Int>,
}
-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
-fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items
.iter()
.map(|item| Data::<f32, 2>::from(item.image))
@@ -71,7 +71,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
let images = Tensor::cat(images, 0).to_device(&self.device);
let targets = Tensor::cat(targets, 0).to_device(&self.device);
-MNISTBatch { images, targets }
+MnistBatch { images, targets }
}
}
```
@@ -81,7 +81,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
The iterator pattern allows you to perform some tasks on a sequence of items in turn.
-In this example, an iterator is created over the `MNISTItem`s in the vector `items` by calling the
+In this example, an iterator is created over the `MnistItem`s in the vector `items` by calling the
`iter` method.
_Iterator adaptors_ are methods defined on the `Iterator` trait that produce different iterators by
@@ -100,7 +100,7 @@ If we go back to the example, we can break down and comment the expression used
images.
```rust, ignore
-let images = items // take items Vec<MNISTItem>
+let images = items // take items Vec<MnistItem>
.iter() // create an iterator over it
.map(|item| Data::<f32, 2>::from(item.image)) // for each item, convert the image to float32 data struct
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device
@@ -115,8 +115,8 @@ Book.
</details><br>
-In the previous example, we implement the `Batcher` trait with a list of `MNISTItem` as input and a
+In the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a
-single `MNISTBatch` as output. The batch contains the images in the form of a 3D tensor, along with
+single `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with
a targets tensor that contains the indexes of the correct digit class. The first step is to parse
the image array into a `Data` struct. Burn provides the `Data` struct to encapsulate tensor storage
information without being specific for a backend. When creating a tensor from data, we often need to
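The book excerpt in the hunks above leans on Rust's iterator adaptors. As a reminder of that pattern outside of Burn, here is a minimal, self-contained sketch using only the standard library; the values are made up for illustration and none of Burn's types are involved:

```rust
fn main() {
    // Stand-in for the `items` vector consumed by the batcher above.
    let items = vec![0.0_f32, 127.0, 255.0];

    // Each adaptor returns a new lazy iterator; nothing runs until
    // `collect` consumes the whole chain.
    let scaled: Vec<f32> = items
        .iter()             // iterate over references to the items
        .map(|x| x / 255.0) // transform each element
        .collect();         // gather the results into a Vec

    assert_eq!(scaled.len(), items.len());
}
```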

View File

@@ -16,7 +16,7 @@ impl ModelConfig {
conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-activation: ReLU::new(),
+activation: Relu::new(),
linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
linear2: LinearConfig::new(self.hidden_size, self.num_classes)
.init_with(record.linear2),
@@ -33,7 +33,7 @@ manually. Everything is validated when loading the model with the record.
Now let's create a simple `infer` method in a new file `src/inference.rs` which we will use to load our trained model.
```rust , ignore
-pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model");
let record = CompactRecorder::new()
@@ -43,7 +43,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
let model = config.model.init_with::<B>(record);
let label = item.label;
-let batcher = MNISTBatcher::new(device);
+let batcher = MnistBatcher::new(device);
let batch = batcher.batch(vec![item]);
let output = model.forward(batch.images);
let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
@@ -56,6 +56,6 @@ The first step is to load the configuration of the training to fetch the correct
configuration. Then we can fetch the record using the same recorder as we used during training.
Finally we can init the model with the configuration and the record before sending it to the wanted
device for inference. For simplicity we can use the same batcher used during the training to pass
-from a MNISTItem to a tensor.
+from a MnistItem to a tensor.
By running the infer function, you should see the predictions of your model!

View File

@@ -35,7 +35,7 @@ use burn::{
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
-Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
+Dropout, DropoutConfig, Linear, LinearConfig, Relu,
},
tensor::{backend::Backend, Tensor},
};
@@ -48,7 +48,7 @@ pub struct Model<B: Backend> {
dropout: Dropout,
linear1: Linear<B>,
linear2: Linear<B>,
-activation: ReLU,
+activation: Relu,
}
```
@@ -98,7 +98,7 @@ There are two major things going on in this code sample.
pub struct MyCustomModule<B: Backend> {
linear1: Linear<B>,
linear2: Linear<B>,
-activation: ReLU,
+activation: Relu,
}
```
@@ -178,7 +178,7 @@ impl ModelConfig {
conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-activation: ReLU::new(),
+activation: Relu::new(),
linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
dropout: DropoutConfig::new(self.dropout).init(),

View File

@@ -43,23 +43,23 @@ Moving forward, we will proceed with the implementation of both the training and
for our model.
```rust , ignore
-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
-fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let item = self.forward_classification(batch.images, batch.targets);
TrainOutput::new(self, item.loss.backward(), item)
}
}
-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
-fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
+fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
self.forward_classification(batch.images, batch.targets)
}
}
```
Here we define the input and output types as generic arguments in the `TrainStep` and `ValidStep`.
-We will call them `MNISTBatch` and `ClassificationOutput`. In the training step, the computation of
+We will call them `MnistBatch` and `ClassificationOutput`. In the training step, the computation of
gradients is straightforward, necessitating a simple invocation of `backward()` on the loss. Note
that contrary to PyTorch, gradients are not stored alongside each tensor parameter, but are rather
returned by the backward pass, as such: `let gradients = loss.backward();`. The gradient of a
@@ -81,8 +81,8 @@ which is generic over the `Backend` trait as has been covered before. These trai
`burn::train` and define a common `step` method that should be implemented for all structs. Since
the trait is generic over the input and output types, the trait implementation must specify the
concrete types used. This is where the additional type constraints appear
-`<MNISTBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the
+`<MnistBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the
-batch is `MNISTBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`
+batch is `MnistBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`
method signature matches the concrete input and output types.
For more details specific to constraints on generic types when defining methods, take a look at
@@ -118,20 +118,20 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
B::seed(config.seed);
-let batcher_train = MNISTBatcher::<B>::new(device.clone());
+let batcher_train = MnistBatcher::<B>::new(device.clone());
-let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::train());
+.build(MnistDataset::train());
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::test());
+.build(MnistDataset::test());
let learner = LearnerBuilder::new(artifact_dir)
.metric_train_numeric(AccuracyMetric::new())
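The book text in this file describes implementing traits that are generic over their input and output types. Here is a minimal sketch of that pattern outside of Burn; the `Step` trait and the types below are simplified stand-ins for illustration, not Burn's actual definitions:

```rust
// Simplified stand-in for Burn's `TrainStep`/`ValidStep` traits.
trait Step<Input, Output> {
    fn step(&self, input: Input) -> Output;
}

struct Model;
struct Batch(Vec<f32>);
struct Output(f32);

// The implementation pins the generic parameters to concrete types, mirroring
// `impl TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B>` above.
impl Step<Batch, Output> for Model {
    fn step(&self, input: Batch) -> Output {
        Output(input.0.iter().sum())
    }
}

fn main() {
    let out = Model.step(Batch(vec![1.0, 2.0, 3.0]));
    assert_eq!(out.0, 6.0);
}
```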

View File

@@ -160,4 +160,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| Burn API | PyTorch Equivalent |
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
-| `MSELoss` | `nn.MSELoss` |
+| `MseLoss` | `nn.MSELoss` |

View File

@@ -40,21 +40,21 @@ pub fn run<B: AutodiffBackend>(device: &B::Device) {
let mut optim = config.optimizer.init();
// Create the batcher.
-let batcher_train = MNISTBatcher::<B>::new(device.clone());
+let batcher_train = MnistBatcher::<B>::new(device.clone());
-let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
// Create the dataloaders.
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::train());
+.build(MnistDataset::train());
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::test());
+.build(MnistDataset::test());
...
}
@@ -140,7 +140,7 @@ Note that after each epoch, we include a validation loop to assess our model's p
previously unseen data. To disable gradient tracking during this validation step, we can invoke
`model.valid()`, which provides a model on the inner backend without autodiff capabilities. It's
important to emphasize that we've declared our validation batcher to be on the inner backend,
-specifically `MNISTBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation
+specifically `MnistBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation
error.
You can find the code above available as an
@@ -195,7 +195,7 @@ where
M: AutodiffModule<B>,
O: Optimizer<M, B>,
{
-pub fn step(&mut self, _batch: MNISTBatch<B>) {
+pub fn step(&mut self, _batch: MnistBatch<B>) {
//
}
}
@@ -214,7 +214,7 @@ the backend and add your trait constraint within its definition:
```rust, ignore
#[allow(dead_code)]
impl<M, O> Learner2<M, O> {
-pub fn step<B: AutodiffBackend>(&mut self, _batch: MNISTBatch<B>)
+pub fn step<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
where
B: AutodiffBackend,
M: AutodiffModule<B>,

View File

@@ -44,7 +44,7 @@ model definition as a simple example.
pub struct Model<B: Backend> {
linear_in: Linear<B>,
linear_out: Linear<B>,
-activation: ReLU,
+activation: Relu,
}
```
@@ -59,7 +59,7 @@ impl<B: Backend> Model<B> {
Model {
linear_in: LinearConfig::new(10, 64).init_with(record.linear_in),
linear_out: LinearConfig::new(64, 2).init_with(record.linear_out),
-activation: ReLU::new(),
+activation: Relu::new(),
}
}
@@ -70,7 +70,7 @@ impl<B: Backend> Model<B> {
Model {
linear_in: l1,
linear_out: l2,
-activation: ReLU::new(),
+activation: Relu::new(),
}
}
}

View File

@@ -5,17 +5,17 @@ use burn_tensor::{backend::Backend, Tensor};
/// Calculate the mean squared error loss from the input logits and the targets.
#[derive(Clone, Debug)]
-pub struct MSELoss<B: Backend> {
+pub struct MseLoss<B: Backend> {
backend: PhantomData<B>,
}
-impl<B: Backend> Default for MSELoss<B> {
+impl<B: Backend> Default for MseLoss<B> {
fn default() -> Self {
Self::new()
}
}
-impl<B: Backend> MSELoss<B> {
+impl<B: Backend> MseLoss<B> {
/// Create the criterion.
pub fn new() -> Self {
Self {
@@ -67,7 +67,7 @@ mod tests {
let targets =
Tensor::<TestBackend, 2>::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]]), &device);
-let mse = MSELoss::new();
+let mse = MseLoss::new();
let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
let loss_sum = mse.forward(logits, targets, Reduction::Sum);

View File

@@ -8,9 +8,9 @@ use crate::tensor::Tensor;
///
/// `y = max(0, x)`
#[derive(Module, Clone, Debug, Default)]
-pub struct ReLU {}
+pub struct Relu {}
-impl ReLU {
+impl Relu {
/// Create the module.
pub fn new() -> Self {
Self {}
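The doc comment in this hunk pins down the whole operation: `y = max(0, x)`. For intuition only, the same function on a plain scalar, with nothing Burn-specific involved:

```rust
/// Scalar ReLU, the `y = max(0, x)` from the doc comment above.
fn relu(x: f32) -> f32 {
    x.max(0.0)
}

fn main() {
    assert_eq!(relu(-1.5), 0.0);
    assert_eq!(relu(2.0), 2.0);
}
```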

View File

@@ -27,14 +27,14 @@ pub struct AdaGradConfig {
/// AdaGrad optimizer
pub struct AdaGrad<B: Backend> {
-lr_decay: LRDecay,
+lr_decay: LrDecay,
weight_decay: Option<WeightDecay<B>>,
}
/// AdaGrad state.
#[derive(Record, Clone, new)]
pub struct AdaGradState<B: Backend, const D: usize> {
-lr_decay: LRDecayState<B, D>,
+lr_decay: LrDecayState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for AdaGrad<B> {
@@ -81,7 +81,7 @@ impl AdaGradConfig {
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> impl Optimizer<M, B> {
let optim = AdaGrad {
-lr_decay: LRDecay {
+lr_decay: LrDecay {
lr_decay: self.lr_decay,
epsilon: self.epsilon,
},
@@ -98,29 +98,29 @@ impl AdaGradConfig {
/// Learning rate decay state (also includes sum state).
#[derive(Record, new, Clone)]
-pub struct LRDecayState<B: Backend, const D: usize> {
+pub struct LrDecayState<B: Backend, const D: usize> {
time: usize,
sum: Tensor<B, D>,
}
-struct LRDecay {
+struct LrDecay {
lr_decay: f64,
epsilon: f32,
}
-impl LRDecay {
+impl LrDecay {
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
lr: LearningRate,
-lr_decay_state: Option<LRDecayState<B, D>>,
+lr_decay_state: Option<LrDecayState<B, D>>,
-) -> (Tensor<B, D>, LRDecayState<B, D>) {
+) -> (Tensor<B, D>, LrDecayState<B, D>) {
let state = if let Some(mut state) = lr_decay_state {
state.sum = state.sum.add(grad.clone().powf_scalar(2.));
state.time += 1;
state
} else {
-LRDecayState::new(1, grad.clone().powf_scalar(2.))
+LrDecayState::new(1, grad.clone().powf_scalar(2.))
};
let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
@@ -133,7 +133,7 @@ impl LRDecay {
}
}
-impl<B: Backend, const D: usize> LRDecayState<B, D> {
+impl<B: Backend, const D: usize> LrDecayState<B, D> {
/// Move state to device.
///
/// # Arguments
@@ -278,7 +278,7 @@ mod tests {
{
let config = AdaGradConfig::new();
AdaGrad {
-lr_decay: LRDecay {
+lr_decay: LrDecay {
lr_decay: config.lr_decay,
epsilon: config.epsilon,
},
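The learning-rate schedule applied by `LrDecay::transform` is visible in the hunk above (`new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay)`). A standalone sketch of that same formula, detached from the tensor state it normally lives next to:

```rust
// Mirrors the schedule in `LrDecay::transform` shown in the diff above:
// new_lr = lr / (1 + (time - 1) * lr_decay)
fn decayed_lr(lr: f64, time: usize, lr_decay: f64) -> f64 {
    lr / (1. + (time as f64 - 1.) * lr_decay)
}

fn main() {
    let (lr, lr_decay) = (0.1, 0.01);
    for time in 1..=3 {
        // The learning rate shrinks slowly as the step counter grows.
        println!("step {time}: lr = {}", decayed_lr(lr, time, lr_decay));
    }
}
```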

View File

@@ -12,19 +12,19 @@ use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{backend::AutodiffBackend, Tensor};
use burn_tensor::backend::Backend;
-/// Configuration to create the [RMSProp](RMSProp) optimizer.
+/// Configuration to create the [RmsProp](RmsProp) optimizer.
#[derive(Config)]
-pub struct RMSPropConfig {
+pub struct RmsPropConfig {
/// Smoothing constant.
#[config(default = 0.99)]
alpha: f32,
-/// momentum for RMSProp.
+/// momentum for RmsProp.
#[config(default = 0.9)]
momentum: f32,
/// A value required for numerical stability.
#[config(default = 1e-5)]
epsilon: f32,
-/// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
+/// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance
#[config(default = false)]
centered: bool,
/// [Weight decay](WeightDecayConfig) config.
@@ -33,22 +33,22 @@ pub struct RMSPropConfig {
grad_clipping: Option<GradientClippingConfig>,
}
-impl RMSPropConfig {
+impl RmsPropConfig {
-/// Initialize RMSProp optimizer.
+/// Initialize RmsProp optimizer.
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
-) -> OptimizerAdaptor<RMSProp<B::InnerBackend>, M, B> {
+) -> OptimizerAdaptor<RmsProp<B::InnerBackend>, M, B> {
let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
-let mut optim = OptimizerAdaptor::from(RMSProp {
+let mut optim = OptimizerAdaptor::from(RmsProp {
alpha: self.alpha,
centered: self.centered,
weight_decay,
-momentum: RMSPropMomentum {
+momentum: RmsPropMomentum {
momentum: self.momentum,
epsilon: self.epsilon,
},
@@ -63,18 +63,18 @@ impl RMSPropConfig {
}
/// Optimizer that implements stochastic gradient descent with momentum.
-/// The optimizer can be configured with [RMSPropConfig](RMSPropConfig).
+/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
-pub struct RMSProp<B: Backend> {
+pub struct RmsProp<B: Backend> {
alpha: f32,
// epsilon: f32,
centered: bool,
// momentum: Option<Momentum<B>>,
-momentum: RMSPropMomentum,
+momentum: RmsPropMomentum,
weight_decay: Option<WeightDecay<B>>,
}
-impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
+impl<B: Backend> SimpleOptimizer<B> for RmsProp<B> {
-type State<const D: usize> = RMSPropState<B, D>;
+type State<const D: usize> = RmsPropState<B, D>;
fn step<const D: usize>(
&self,
@@ -117,7 +117,7 @@ impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
.transform(grad, state_centered, state_momentum);
// transition state
-let state = RMSPropState::new(state_square_avg, state_centered, state_momentum);
+let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);
// tensor param transform
let delta = grad.mul_scalar(lr);
@@ -135,12 +135,12 @@ impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
}
}
-/// State of [RMSProp](RMSProp)
+/// State of [RmsProp](RmsProp)
#[derive(Record, Clone, new)]
-pub struct RMSPropState<B: Backend, const D: usize> {
+pub struct RmsPropState<B: Backend, const D: usize> {
square_avg: SquareAvgState<B, D>,
centered: CenteredState<B, D>,
-momentum: Option<RMSPropMomentumState<B, D>>,
+momentum: Option<RmsPropMomentumState<B, D>>,
}
/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
@@ -249,24 +249,24 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
}
}
-/// [RMSPropMomentum](RMSPropMomentum) is to store config status for optimizer.
+/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
-/// (, which is stored in [optimizer](RMSProp) itself and not passed in during `step()` calculation)
+/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
-pub struct RMSPropMomentum {
+pub struct RmsPropMomentum {
momentum: f32,
epsilon: f32,
}
-impl RMSPropMomentum {
+impl RmsPropMomentum {
-/// transform [grad](Tensor) and [RMSPropMomentumState] to the next step
+/// transform [grad](Tensor) and [RmsPropMomentumState] to the next step
fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
centered_state: CenteredState<B, D>,
-momentum_state: Option<RMSPropMomentumState<B, D>>,
+momentum_state: Option<RmsPropMomentumState<B, D>>,
) -> (
Tensor<B, D>,
CenteredState<B, D>,
-Option<RMSPropMomentumState<B, D>>,
+Option<RmsPropMomentumState<B, D>>,
) {
let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
@@ -278,7 +278,7 @@ impl RMSPropMomentum {
(
buf.clone(),
centered_state,
-Some(RMSPropMomentumState { buf }),
+Some(RmsPropMomentumState { buf }),
)
} else {
(grad, centered_state, None)
@@ -286,13 +286,13 @@ impl RMSPropMomentum {
}
}
-/// [RMSPropMomentumState](RMSPropMomentumState) is to store and pass optimizer step params.
+/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.
#[derive(Record, Clone, new)]
-pub struct RMSPropMomentumState<B: Backend, const D: usize> {
+pub struct RmsPropMomentumState<B: Backend, const D: usize> {
buf: Tensor<B, D>,
}
-impl<B: Backend, const D: usize> RMSPropMomentumState<B, D> {
+impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
/// Moves the state to a device.
///
/// # Arguments
@@ -378,7 +378,7 @@ mod tests {
)
.require_grad();
-let mut optimizer = RMSPropConfig::new()
+let mut optimizer = RmsPropConfig::new()
.with_alpha(0.99)
.with_epsilon(1e-8)
.with_weight_decay(WeightDecayConfig::new(0.05).into())
@@ -453,7 +453,7 @@ mod tests {
)
.require_grad();
-let mut optimizer = RMSPropConfig::new()
+let mut optimizer = RmsPropConfig::new()
.with_alpha(0.99)
.with_epsilon(1e-8)
.with_weight_decay(WeightDecayConfig::new(0.05).into())
@@ -529,9 +529,9 @@ mod tests {
}
fn create_rmsprop(
-) -> OptimizerAdaptor<RMSProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
+) -> OptimizerAdaptor<RmsProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
{
-RMSPropConfig {
+RmsPropConfig {
alpha: 0.99,
epsilon: 1e-9,
centered: false,
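For orientation while reading the renamed types, here is a scalar sketch of the textbook, non-centered RMSProp update with momentum. It is an illustration of the algorithm the types above implement, not a copy of Burn's tensor code; the constants mirror the defaults visible in `RmsPropConfig` (alpha 0.99, momentum 0.9, epsilon 1e-5):

```rust
// Textbook RmsProp on a single scalar parameter.
struct ScalarRmsProp {
    alpha: f32,      // smoothing constant for the running square average
    epsilon: f32,    // numerical stability term
    momentum: f32,   // momentum factor for the update buffer
    square_avg: f32, // running average of squared gradients
    buf: f32,        // momentum buffer
}

impl ScalarRmsProp {
    /// Returns the delta to subtract from the parameter.
    fn step(&mut self, lr: f32, grad: f32) -> f32 {
        self.square_avg = self.alpha * self.square_avg + (1.0 - self.alpha) * grad * grad;
        let normalized = grad / (self.square_avg.sqrt() + self.epsilon);
        self.buf = self.momentum * self.buf + normalized;
        lr * self.buf
    }
}

fn main() {
    let mut opt = ScalarRmsProp { alpha: 0.99, epsilon: 1e-5, momentum: 0.9, square_avg: 0.0, buf: 0.0 };
    let mut w = 1.0_f32;
    for _ in 0..3 {
        let grad = 2.0 * w; // gradient of the toy objective w^2
        w -= opt.step(0.01, grad);
    }
    println!("w after 3 steps: {w}");
}
```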

View File

@@ -49,12 +49,12 @@ pub enum ImporterError {
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
-/// struct MNISTItemRaw {
+/// struct MnistItemRaw {
/// pub image_bytes: Vec<u8>,
/// pub label: usize,
/// }
///
-/// let train_ds:SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
+/// let train_ds:SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
/// .dataset("train")
/// .unwrap();
pub struct HuggingfaceDatasetLoader {

View File

@@ -24,7 +24,7 @@ const HEIGHT: usize = 28;
/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
-pub struct MNISTItem {
+pub struct MnistItem {
/// Image as a 2D array of floats.
pub image: [[f32; WIDTH]; HEIGHT],
@@ -33,16 +33,16 @@ pub struct MNISTItem {
}
#[derive(Deserialize, Debug, Clone)]
-struct MNISTItemRaw {
+struct MnistItemRaw {
pub image_bytes: Vec<u8>,
pub label: u8,
}
struct BytesToImage;
-impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
+impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
-fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
+fn map(&self, item: &MnistItemRaw) -> MnistItem {
// Ensure the image dimensions are correct.
debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
@@ -54,25 +54,25 @@ impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
image_array[y][x] = *pixel as f32;
}
-MNISTItem {
+MnistItem {
image: image_array,
label: item.label,
}
}
}
-type MappedDataset = MapperDataset<InMemDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
+type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;
/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
/// images per class. There are 60,000 training images and 10,000 test images.
///
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
-pub struct MNISTDataset {
+pub struct MnistDataset {
dataset: MappedDataset,
}
-impl Dataset<MNISTItem> for MNISTDataset {
+impl Dataset<MnistItem> for MnistDataset {
-fn get(&self, index: usize) -> Option<MNISTItem> {
+fn get(&self, index: usize) -> Option<MnistItem> {
self.dataset.get(index)
}
@@ -81,7 +81,7 @@ impl Dataset<MNISTItem> for MNISTDataset {
}
}
-impl MNISTDataset {
+impl MnistDataset {
/// Creates a new train dataset.
pub fn train() -> Self {
Self::new("train")
@@ -94,19 +94,19 @@ impl MNISTDataset {
fn new(split: &str) -> Self {
// Download dataset
-let root = MNISTDataset::download(split);
+let root = MnistDataset::download(split);
// MNIST is tiny so we can load it in-memory
// Train images (u8): 28 * 28 * 60000 = 47.04Mb
// Test images (u8): 28 * 28 * 10000 = 7.84Mb
-let images = MNISTDataset::read_images(&root, split);
+let images = MnistDataset::read_images(&root, split);
-let labels = MNISTDataset::read_labels(&root, split);
+let labels = MnistDataset::read_labels(&root, split);
-// Collect as vector of MNISTItemRaw
+// Collect as vector of MnistItemRaw
let items: Vec<_> = images
.into_iter()
.zip(labels)
-.map(|(image_bytes, label)| MNISTItemRaw { image_bytes, label })
+.map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
.collect();
let dataset = InMemDataset::new(items);
@@ -132,12 +132,12 @@ impl MNISTDataset {
// Download split files
match split {
"train" => {
-MNISTDataset::download_file(TRAIN_IMAGES, &split_dir);
+MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
-MNISTDataset::download_file(TRAIN_LABELS, &split_dir);
+MnistDataset::download_file(TRAIN_LABELS, &split_dir);
}
"test" => {
-MNISTDataset::download_file(TEST_IMAGES, &split_dir);
+MnistDataset::download_file(TEST_IMAGES, &split_dir);
-MNISTDataset::download_file(TEST_LABELS, &split_dir);
+MnistDataset::download_file(TEST_LABELS, &split_dir);
}
_ => panic!("Invalid split specified {}", split),
};
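The mapping performed by `BytesToImage` above is easy to state on its own: 28 x 28 raw bytes become a `[[f32; 28]; 28]` array, one `f32` per pixel. A self-contained sketch of that conversion, with the same shape as the loop fragment visible in the hunk but detached from the dataset plumbing:

```rust
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

// Convert a flat buffer of pixel bytes into a 2D array of floats,
// mirroring the body of `BytesToImage::map` shown in the diff above.
fn bytes_to_image(image_bytes: &[u8]) -> [[f32; WIDTH]; HEIGHT] {
    // Ensure the image dimensions are correct.
    debug_assert_eq!(image_bytes.len(), WIDTH * HEIGHT);
    let mut image_array = [[0.0_f32; WIDTH]; HEIGHT];
    for (i, pixel) in image_bytes.iter().enumerate() {
        let (y, x) = (i / WIDTH, i % WIDTH);
        image_array[y][x] = *pixel as f32;
    }
    image_array
}

fn main() {
    let raw = vec![0_u8; WIDTH * HEIGHT];
    let image = bytes_to_image(&raw);
    assert_eq!(image.len(), HEIGHT);
}
```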

View File

@@ -1,6 +1,6 @@
use burn::{
module::Module,
-nn::{Linear, LinearConfig, ReLU},
+nn::{Linear, LinearConfig, Relu},
tensor::{backend::Backend, Tensor},
};
@@ -8,7 +8,7 @@ use burn::{
pub struct Net<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
-relu: ReLU,
+relu: Relu,
}
impl<B: Backend> Net<B> {
@@ -16,7 +16,7 @@ impl<B: Backend> Net<B> {
pub fn new_with(record: NetRecord<B>) -> Self {
let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);
let fc2 = LinearConfig::new(3, 4).init_with(record.fc2);
-let relu = ReLU::default();
+let relu = Relu::default();
Self { fc1, fc2, relu }
}

View File

@@ -9,7 +9,7 @@ use crate::onnx::{
proto_conversion::convert_node_proto,
};
-use super::ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor};
+use super::ir::{ArgType, Argument, Node, NodeType, OnnxGraph, Tensor};
use super::protos::{ModelProto, TensorProto};
use super::{dim_inference::dim_inference, protos::ValueInfoProto};
@@ -33,14 +33,14 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
///
/// # Returns
///
-/// * `ONNXGraph` - The graph representation of the onnx file
+/// * `OnnxGraph` - The graph representation of the onnx file
///
/// # Panics
///
/// * If the file cannot be opened
/// * If the file cannot be parsed
/// * If the nodes are not topologically sorted
-pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
+pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
log::info!("Parsing ONNX file: {}", onnx_path.display());
// Open the file
@@ -118,7 +118,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
log::info!("Finished parsing ONNX file: {}", onnx_path.display());
-ONNXGraph {
+OnnxGraph {
nodes,
inputs,
outputs,

View File

@@ -129,7 +129,7 @@ pub enum Data {
/// ONNX graph representation
#[derive(Debug, Clone)]
-pub struct ONNXGraph {
+pub struct OnnxGraph {
/// The nodes of the graph.
pub nodes: Vec<Node>,

View File

@@ -11,4 +11,4 @@ mod to_burn;
pub use to_burn::*;
pub use from_onnx::parse_onnx;
-pub use ir::ONNXGraph;
+pub use ir::OnnxGraph;

View File

@@ -45,7 +45,7 @@ use crate::{
use super::{
from_onnx::parse_onnx,
-ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph},
+ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
op_configuration::{
avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
softmax_config,
@@ -218,7 +218,7 @@ impl ModelGen {
}
}
-impl ONNXGraph {
+impl OnnxGraph {
/// Converts ONNX graph to Burn graph.
pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
let mut graph = BurnGraph::<PS>::default();

View File

@@ -28,7 +28,7 @@ pub struct MlpConfig {
pub struct Mlp<B: Backend> {
linears: Vec<nn::Linear<B>>,
dropout: nn::Dropout,
-activation: nn::ReLU,
+activation: nn::Relu,
}
impl<B: Backend> Mlp<B> {
@@ -43,7 +43,7 @@ impl<B: Backend> Mlp<B> {
Self {
linears,
dropout: nn::DropoutConfig::new(0.3).init(),
-activation: nn::ReLU::new(),
+activation: nn::Relu::new(),
}
}

View File

@@ -3,11 +3,11 @@ use crate::metric::{Metric, MetricEntry};
use nvml_wrapper::Nvml;
/// Track basic cuda infos.
-pub struct CUDAMetric {
+pub struct CudaMetric {
nvml: Option<Nvml>,
}
-impl CUDAMetric {
+impl CudaMetric {
/// Creates a new metric for CUDA.
pub fn new() -> Self {
Self {
@@ -19,7 +19,7 @@ impl CUDAMetric {
}
}
-impl Default for CUDAMetric {
+impl Default for CudaMetric {
fn default() -> Self {
Self::new()
}
@@ -29,7 +29,7 @@ impl<T> Adaptor<()> for T {
fn adapt(&self) {}
}
-impl Metric for CUDAMetric {
+impl Metric for CudaMetric {
const NAME: &'static str = "CUDA Stats";
type Input = ();

View File

@@ -61,7 +61,7 @@
//! - `audio`: Enables audio datasets (SpeechCommandsDataset)
//! - `sqlite`: Stores datasets in SQLite database
//! - `sqlite_bundled`: Use bundled version of SQLite
-//! - `vision`: Enables vision datasets (MNISTDataset)
+//! - `vision`: Enables vision datasets (MnistDataset)
//! - Backends
//! - `wgpu`: Makes available the WGPU backend
//! - `candle`: Makes available the Candle backend

View File

@@ -3,7 +3,7 @@ use burn::{
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{MaxPool2d, MaxPool2dConfig},
-Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, ReLU,
+Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
},
tensor::{backend::Backend, Device, Tensor},
};
@@ -23,8 +23,8 @@ use burn::{
// │ maxpool │
// └────────────────────┘
#[derive(Module, Debug)]
-pub struct CNN<B: Backend> {
+pub struct Cnn<B: Backend> {
-activation: ReLU,
+activation: Relu,
dropout: Dropout,
pool: MaxPool2d,
conv1: Conv2d<B>,
@@ -37,7 +37,7 @@ pub struct CNN<B: Backend> {
fc2: Linear<B>,
}
-impl<B: Backend> CNN<B> {
+impl<B: Backend> Cnn<B> {
pub fn new(num_classes: usize, device: &Device<B>) -> Self {
let conv1 = Conv2dConfig::new([3, 32], [3, 3])
.with_padding(PaddingConfig2d::Same)
@@ -68,7 +68,7 @@ impl<B: Backend> CNN<B> {
let dropout = DropoutConfig::new(0.3).init();
Self {
-activation: ReLU::new(),
+activation: Relu::new(),
dropout,
pool,
conv1,

View File

@@ -3,7 +3,7 @@ use std::time::Instant;
use crate::{
data::{ClassificationBatch, ClassificationBatcher},
dataset::CIFAR10Loader,
-model::CNN,
+model::Cnn,
};
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
use burn::train::{
@@ -26,7 +26,7 @@ use burn::{
const NUM_CLASSES: u8 = 10;
const ARTIFACT_DIR: &str = "/tmp/custom-image-dataset";
-impl<B: Backend> CNN<B> {
+impl<B: Backend> Cnn<B> {
pub fn forward_classification(
&self,
images: Tensor<B, 4>,
@@ -41,7 +41,7 @@ impl<B: Backend> CNN<B> {
}
}
-impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<B>> for CNN<B> {
+impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<B>> for Cnn<B> {
fn step(&self, batch: ClassificationBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let item = self.forward_classification(batch.images, batch.targets);
@@ -49,7 +49,7 @@ impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<
}
}
-impl<B: Backend> ValidStep<ClassificationBatch<B>, ClassificationOutput<B>> for CNN<B> {
+impl<B: Backend> ValidStep<ClassificationBatch<B>, ClassificationOutput<B>> for Cnn<B> {
fn step(&self, batch: ClassificationBatch<B>) -> ClassificationOutput<B> {
self.forward_classification(batch.images, batch.targets)
}
@@ -104,7 +104,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.build(
-CNN::new(NUM_CLASSES.into(), &device),
+Cnn::new(NUM_CLASSES.into(), &device),
config.optimizer.init(),
config.learning_rate,
);

View File

@@ -1,11 +1,11 @@
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
use burn::train::LearnerBuilder;
use burn::{
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
tensor::backend::AutodiffBackend,
};
-use guide::{data::MNISTBatcher, model::ModelConfig};
+use guide::{data::MnistBatcher, model::ModelConfig};
#[derive(Config)]
pub struct MnistTrainingConfig {
@@ -52,21 +52,21 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
let optim = config.optimizer.init();
// Create the batcher.
-let batcher_train = MNISTBatcher::<B>::new(device.clone());
+let batcher_train = MnistBatcher::<B>::new(device.clone());
-let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
// Create the dataloaders.
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::train());
+.build(MnistDataset::train());
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::test());
+.build(MnistDataset::test());
// artifact dir does not need to be provided when log_to_file is false
let builder = LearnerBuilder::new("")

View File

@@ -1,6 +1,6 @@
use std::marker::PhantomData;
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,
@@ -13,7 +13,7 @@ use burn::{
},
};
use guide::{
-data::{MNISTBatch, MNISTBatcher},
+data::{MnistBatch, MnistBatcher},
model::{Model, ModelConfig},
};
@@ -46,21 +46,21 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
let mut optim = config.optimizer.init();
// Create the batcher.
-let batcher_train = MNISTBatcher::<B>::new(device.clone());
+let batcher_train = MnistBatcher::<B>::new(device.clone());
-let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
// Create the dataloaders.
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::train());
+.build(MnistDataset::train());
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
-.build(MNISTDataset::test());
+.build(MnistDataset::test());
// Iterate over our training and validation loop for X epochs.
for epoch in 1..config.num_epochs + 1 {
@@ -145,7 +145,7 @@ where
B: AutodiffBackend,
O: Optimizer<Model<B>, B>,
{
-pub fn step1(&mut self, _batch: MNISTBatch<B>) {
+pub fn step1(&mut self, _batch: MnistBatch<B>) {
//
}
}
@@ -156,14 +156,14 @@ where
B: AutodiffBackend,
O: Optimizer<Model<B>, B>,
{
-pub fn step2(&mut self, _batch: MNISTBatch<B>) {
+pub fn step2(&mut self, _batch: MnistBatch<B>) {
//
}
}
#[allow(dead_code)]
impl<M, O> Learner2<M, O> {
-pub fn step3<B: AutodiffBackend>(&mut self, _batch: MNISTBatch<B>)
+pub fn step3<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
where
B: AutodiffBackend,
M: AutodiffModule<B>,

View File

@@ -18,7 +18,7 @@ fn main() {
guide::inference::infer::<MyBackend>(
artifact_dir,
device,
-burn::data::dataset::vision::MNISTDataset::test()
+burn::data::dataset::vision::MnistDataset::test()
.get(42)
.unwrap(),
);

View File

@@ -1,26 +1,26 @@
use burn::{
-data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};
-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
device: B::Device,
}
-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}
#[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 3>,
pub targets: Tensor<B, 1, Int>,
}
-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
-fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items
.iter()
.map(|item| Data::<f32, 2>::from(item.image))
@@ -40,6 +40,6 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
let images = Tensor::cat(images, 0);
let targets = Tensor::cat(targets, 0);
-MNISTBatch { images, targets }
+MnistBatch { images, targets }
}
}

View File

@@ -1,5 +1,5 @@
-use crate::{data::MNISTBatcher, training::TrainingConfig};
+use crate::{data::MnistBatcher, training::TrainingConfig};
-use burn::data::dataset::vision::MNISTItem;
+use burn::data::dataset::vision::MnistItem;
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
@@ -7,7 +7,7 @@ use burn::{
tensor::backend::Backend,
};
-pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model");
let record = CompactRecorder::new()
@@ -17,7 +17,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
let model = config.model.init_with::<B>(record);
let label = item.label;
-let batcher = MNISTBatcher::new(device);
+let batcher = MnistBatcher::new(device);
let batch = batcher.batch(vec![item]);
let output = model.forward(batch.images);
let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();

View File

@@ -4,7 +4,7 @@ use burn::{
     nn::{
         conv::{Conv2d, Conv2dConfig},
         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
-        Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
+        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
     },
     tensor::{backend::Backend, Tensor},
 };
@@ -17,7 +17,7 @@ pub struct Model<B: Backend> {
     dropout: Dropout,
     linear1: Linear<B>,
     linear2: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
 }
 
 #[derive(Config, Debug)]
@@ -35,7 +35,7 @@ impl ModelConfig {
             conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
             conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
             pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
             linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
             linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
             dropout: DropoutConfig::new(self.dropout).init(),
@@ -47,7 +47,7 @@ impl ModelConfig {
             conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
             conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
             pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
             linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
             linear2: LinearConfig::new(self.hidden_size, self.num_classes)
                 .init_with(record.linear2),
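Outside of the model definition, the renamed activation behaves exactly like the old `ReLU` unit struct. A minimal sketch, assuming the `NdArray` backend and the device-aware `from_floats` constructor available around this release:

```rust, ignore
use burn::backend::ndarray::NdArray;
use burn::nn::Relu;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let activation = Relu::new();

    // Negative entries are clamped to zero; positives pass through unchanged.
    let x = Tensor::<NdArray, 2>::from_floats([[-1.0, 0.5], [2.0, -3.0]], &device);
    let y = activation.forward(x);

    println!("{}", y);
}
```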

View File

@@ -1,8 +1,8 @@
 use crate::{
-    data::{MNISTBatch, MNISTBatcher},
+    data::{MnistBatch, MnistBatcher},
     model::{Model, ModelConfig},
 };
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::train::{
     metric::{AccuracyMetric, LossMetric},
     ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
@@ -36,16 +36,16 @@ impl<B: Backend> Model<B> {
     }
 }
 
-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
         let item = self.forward_classification(batch.images, batch.targets);
 
         TrainOutput::new(self, item.loss.backward(), item)
     }
 }
 
-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
         self.forward_classification(batch.images, batch.targets)
     }
 }
@@ -74,20 +74,20 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
     B::seed(config.seed);
 
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
 
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());
 
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());
 
     let learner = LearnerBuilder::new(artifact_dir)
         .metric_train_numeric(AccuracyMetric::new())

View File

@@ -1,26 +1,26 @@
 use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
     tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
 };
 
-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
     device: B::Device,
 }
 
 #[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
     pub images: Tensor<B, 3>,
     pub targets: Tensor<B, 1, Int>,
 }
 
-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
     pub fn new(device: B::Device) -> Self {
         Self { device }
     }
 }
 
-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
         let images = items
             .iter()
             .map(|item| Data::<f32, 2>::from(item.image))
@@ -45,6 +45,6 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
         let images = Tensor::cat(images, 0);
         let targets = Tensor::cat(targets, 0);
 
-        MNISTBatch { images, targets }
+        MnistBatch { images, targets }
     }
 }

View File

@@ -1,4 +1,4 @@
-use crate::data::MNISTBatch;
+use crate::data::MnistBatch;
 use burn::{
     module::Module,
     nn::{self, loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},
@@ -73,7 +73,7 @@ impl<B: Backend> Model<B> {
         self.fc2.forward(x)
     }
 
-    pub fn forward_classification(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
+    pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
         let targets = item.targets;
         let output = self.forward(item.images);
         let loss = CrossEntropyLossConfig::new()
@@ -117,16 +117,16 @@ impl<B: Backend> ConvBlock<B> {
     }
 }
 
-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, item: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, item: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
         let item = self.forward_classification(item);
 
         TrainOutput::new(self, item.loss.backward(), item)
     }
 }
 
-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
         self.forward_classification(item)
     }
 }

View File

@@ -1,4 +1,4 @@
-use crate::data::MNISTBatcher;
+use crate::data::MnistBatcher;
 use crate::model::Model;
 use burn::module::Module;
@@ -10,7 +10,7 @@ use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
 use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
 use burn::{
     config::Config,
-    data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
+    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
     tensor::backend::AutodiffBackend,
     train::{
         metric::{AccuracyMetric, LossMetric},
@@ -44,19 +44,19 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
     B::seed(config.seed);
 
     // Data
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
 
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());
 
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());
 
     // Model
     let learner = LearnerBuilder::new(ARTIFACT_DIR)

View File

@@ -3,7 +3,7 @@ use std::env::args
 use burn::backend::ndarray::NdArray;
 use burn::tensor::Tensor;
 
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::data::dataset::Dataset;
 
 use onnx_inference::mnist::Model;
@@ -34,7 +34,7 @@ fn main() {
     let model: Model<Backend> = Model::default();
 
     // Load the MNIST dataset and get an item
-    let dataset = MNISTDataset::test();
+    let dataset = MnistDataset::test();
     let item = dataset.get(image_index).unwrap();
 
     // Create a tensor from the image data

View File

@@ -5,7 +5,7 @@ use burn::backend::ndarray::NdArray;
 use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
 use burn::tensor::Tensor;
 
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::data::dataset::Dataset;
 
 use model::Model;
@@ -42,7 +42,7 @@ fn main() {
     let model: Model<Backend> = Model::new_with(record);
 
     // Load the MNIST dataset and get an item
-    let dataset = MNISTDataset::test();
+    let dataset = MnistDataset::test();
     let item = dataset.get(image_index).unwrap();
 
     // Create a tensor from the image data

View File

@@ -1,10 +1,10 @@
 use crate::dataset::DiabetesBatch;
 use burn::config::Config;
 use burn::nn::loss::Reduction::Mean;
-use burn::nn::ReLU;
+use burn::nn::Relu;
 use burn::{
     module::Module,
-    nn::{loss::MSELoss, Linear, LinearConfig},
+    nn::{loss::MseLoss, Linear, LinearConfig},
     tensor::{
         backend::{AutodiffBackend, Backend},
         Tensor,
@@ -16,7 +16,7 @@ use burn::{
 pub struct RegressionModel<B: Backend> {
     input_layer: Linear<B>,
     output_layer: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
 }
 
 #[derive(Config)]
@@ -39,7 +39,7 @@ impl RegressionModelConfig {
         RegressionModel {
             input_layer,
             output_layer,
-            activation: ReLU::new(),
+            activation: Relu::new(),
         }
     }
 }
@@ -56,7 +56,7 @@ impl<B: Backend> RegressionModel<B> {
         let targets: Tensor<B, 2> = item.targets.unsqueeze();
         let output: Tensor<B, 2> = self.forward(item.inputs);
 
-        let loss = MSELoss::new().forward(output.clone(), targets.clone(), Mean);
+        let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);
 
         RegressionOutput {
             loss,
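For completeness, a standalone sketch of the renamed loss (not part of this commit); the tensors are toy values, the backend is an assumption, and the device-aware `from_floats` constructor is assumed to be available.

```rust, ignore
use burn::backend::ndarray::NdArray;
use burn::nn::loss::{MseLoss, Reduction};
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();

    let output = Tensor::<NdArray, 2>::from_floats([[1.0], [2.0]], &device);
    let targets = Tensor::<NdArray, 2>::from_floats([[1.5], [1.0]], &device);

    // Same call shape as in the diff above: forward(output, targets, reduction).
    let loss = MseLoss::new().forward(output, targets, Reduction::Mean);
    println!("{}", loss);
}
```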

View File

@@ -19,7 +19,7 @@ use burn::{
     record::{CompactRecorder, Recorder},
     tensor::backend::AutodiffBackend,
     train::{
-        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
         LearnerBuilder,
     },
 };
@@ -91,8 +91,8 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
     // Initialize learner
     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train(CUDAMetric::new())
-        .metric_valid(CUDAMetric::new())
+        .metric_train(CudaMetric::new())
+        .metric_valid(CudaMetric::new())
         .metric_train_numeric(AccuracyMetric::new())
         .metric_valid_numeric(AccuracyMetric::new())
         .metric_train_numeric(LossMetric::new())

View File

@@ -13,7 +13,7 @@ use burn::{
     record::{CompactRecorder, DefaultRecorder, Recorder},
     tensor::backend::AutodiffBackend,
     train::{
-        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
         LearnerBuilder,
     },
 };
@@ -68,8 +68,8 @@ pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
         .init();
 
     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train(CUDAMetric::new())
-        .metric_valid(CUDAMetric::new())
+        .metric_train(CudaMetric::new())
+        .metric_valid(CudaMetric::new())
         .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
         .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
         .metric_train(LossMetric::new())