mirror of https://github.com/tracel-ai/burn.git
Make all struct CamelCase (#1316)
This commit is contained in:
parent dfb739c89a
commit 44266d5fd4
@@ -194,7 +194,7 @@ impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
```

Subsequently, we'll go into implementing our custom backend trait for the WGPU backend.
-Note that we won't go into supporting the `fusion` feature flag in this tutorial, so
+Note that we won't go into supporting the `fusion` feature flag in this tutorial, so
we implement the trait for the raw `WgpuBackend` type.

```rust, ignore
@@ -16,15 +16,15 @@ at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/exa

```rust , ignore
use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};

-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
@@ -42,13 +42,13 @@ Next, we need to actually implement the batching logic.

```rust , ignore
#[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| Data::<f32, 2>::from(item.image))
@@ -71,7 +71,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
        let images = Tensor::cat(images, 0).to_device(&self.device);
        let targets = Tensor::cat(targets, 0).to_device(&self.device);

-        MNISTBatch { images, targets }
+        MnistBatch { images, targets }
    }
}
```
@@ -81,7 +81,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {

The iterator pattern allows you to perform some tasks on a sequence of items in turn.

-In this example, an iterator is created over the `MNISTItem`s in the vector `items` by calling the
+In this example, an iterator is created over the `MnistItem`s in the vector `items` by calling the
`iter` method.

_Iterator adaptors_ are methods defined on the `Iterator` trait that produce different iterators by
@@ -100,7 +100,7 @@ If we go back to the example, we can break down and comment the expression used
images.

```rust, ignore
-let images = items // take items Vec<MNISTItem>
+let images = items // take items Vec<MnistItem>
    .iter() // create an iterator over it
    .map(|item| Data::<f32, 2>::from(item.image)) // for each item, convert the image to float32 data struct
    .map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device
@@ -115,8 +115,8 @@ Book.

</details><br>

-In the previous example, we implement the `Batcher` trait with a list of `MNISTItem` as input and a
-single `MNISTBatch` as output. The batch contains the images in the form of a 3D tensor, along with
+In the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a
+single `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with
a targets tensor that contains the indexes of the correct digit class. The first step is to parse
the image array into a `Data` struct. Burn provides the `Data` struct to encapsulate tensor storage
information without being specific for a backend. When creating a tensor from data, we often need to
@@ -16,7 +16,7 @@ impl ModelConfig {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes)
                .init_with(record.linear2),
@@ -33,7 +33,7 @@ manually. Everything is validated when loading the model with the record.
Now let's create a simple `infer` method in a new file `src/inference.rs` which we will use to load our trained model.

```rust , ignore
-pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist for the model");
    let record = CompactRecorder::new()
@@ -43,7 +43,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
    let model = config.model.init_with::<B>(record);

    let label = item.label;
-    let batcher = MNISTBatcher::new(device);
+    let batcher = MnistBatcher::new(device);
    let batch = batcher.batch(vec![item]);
    let output = model.forward(batch.images);
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
@@ -56,6 +56,6 @@ The first step is to load the configuration of the training to fetch the correct
configuration. Then we can fetch the record using the same recorder as we used during training.
Finally we can init the model with the configuration and the record before sending it to the wanted
device for inference. For simplicity we can use the same batcher used during the training to pass
-from a MNISTItem to a tensor.
+from a MnistItem to a tensor.

By running the infer function, you should see the predictions of your model!
@@ -35,7 +35,7 @@ use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
-        Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
+        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    tensor::{backend::Backend, Tensor},
};
@@ -48,7 +48,7 @@ pub struct Model<B: Backend> {
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
}
```

@@ -98,7 +98,7 @@ There are two major things going on in this code sample.
pub struct MyCustomModule<B: Backend> {
    linear1: Linear<B>,
    linear2: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
}
```

@@ -178,7 +178,7 @@ impl ModelConfig {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
@@ -43,23 +43,23 @@ Moving forward, we will proceed with the implementation of both the training and
for our model.

```rust , ignore
-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}
```

Here we define the input and output types as generic arguments in the `TrainStep` and `ValidStep`.
-We will call them `MNISTBatch` and `ClassificationOutput`. In the training step, the computation of
+We will call them `MnistBatch` and `ClassificationOutput`. In the training step, the computation of
gradients is straightforward, necessitating a simple invocation of `backward()` on the loss. Note
that contrary to PyTorch, gradients are not stored alongside each tensor parameter, but are rather
returned by the backward pass, as such: `let gradients = loss.backward();`. The gradient of a
@@ -81,8 +81,8 @@ which is generic over the `Backend` trait as has been covered before. These trai
`burn::train` and define a common `step` method that should be implemented for all structs. Since
the trait is generic over the input and output types, the trait implementation must specify the
concrete types used. This is where the additional type constraints appear
-`<MNISTBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the
-batch is `MNISTBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`
+`<MnistBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the
+batch is `MnistBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`
method signature matches the concrete input and output types.

For more details specific to constraints on generic types when defining methods, take a look at
@@ -118,20 +118,20 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev

    B::seed(config.seed);

-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
@@ -160,4 +160,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| Burn API | PyTorch Equivalent |
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
-| `MSELoss` | `nn.MSELoss` |
+| `MseLoss` | `nn.MSELoss` |
@@ -40,21 +40,21 @@ pub fn run<B: AutodiffBackend>(device: &B::Device) {
    let mut optim = config.optimizer.init();

    // Create the batcher.
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    // Create the dataloaders.
    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

    ...
}
@@ -140,7 +140,7 @@ Note that after each epoch, we include a validation loop to assess our model's p
previously unseen data. To disable gradient tracking during this validation step, we can invoke
`model.valid()`, which provides a model on the inner backend without autodiff capabilities. It's
important to emphasize that we've declared our validation batcher to be on the inner backend,
-specifically `MNISTBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation
+specifically `MnistBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation
error.

You can find the code above available as an
@@ -195,7 +195,7 @@ where
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
{
-    pub fn step(&mut self, _batch: MNISTBatch<B>) {
+    pub fn step(&mut self, _batch: MnistBatch<B>) {
        //
    }
}
@@ -214,7 +214,7 @@ the backend and add your trait constraint within its definition:
```rust, ignore
#[allow(dead_code)]
impl<M, O> Learner2<M, O> {
-    pub fn step<B: AutodiffBackend>(&mut self, _batch: MNISTBatch<B>)
+    pub fn step<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
    where
        B: AutodiffBackend,
        M: AutodiffModule<B>,
@@ -44,7 +44,7 @@ model definition as a simple example.
pub struct Model<B: Backend> {
    linear_in: Linear<B>,
    linear_out: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
}
```

@@ -59,7 +59,7 @@ impl<B: Backend> Model<B> {
        Model {
            linear_in: LinearConfig::new(10, 64).init_with(record.linear_in),
            linear_out: LinearConfig::new(64, 2).init_with(record.linear_out),
-            activation: ReLU::new(),
+            activation: Relu::new(),
        }
    }

@@ -70,7 +70,7 @@ impl<B: Backend> Model<B> {
        Model {
            linear_in: l1,
            linear_out: l2,
-            activation: ReLU::new(),
+            activation: Relu::new(),
        }
    }
}
@@ -5,17 +5,17 @@ use burn_tensor::{backend::Backend, Tensor};

/// Calculate the mean squared error loss from the input logits and the targets.
#[derive(Clone, Debug)]
-pub struct MSELoss<B: Backend> {
+pub struct MseLoss<B: Backend> {
    backend: PhantomData<B>,
}

-impl<B: Backend> Default for MSELoss<B> {
+impl<B: Backend> Default for MseLoss<B> {
    fn default() -> Self {
        Self::new()
    }
}

-impl<B: Backend> MSELoss<B> {
+impl<B: Backend> MseLoss<B> {
    /// Create the criterion.
    pub fn new() -> Self {
        Self {
@@ -67,7 +67,7 @@ mod tests {
        let targets =
            Tensor::<TestBackend, 2>::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]]), &device);

-        let mse = MSELoss::new();
+        let mse = MseLoss::new();
        let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
        let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = mse.forward(logits, targets, Reduction::Sum);
@@ -8,9 +8,9 @@ use crate::tensor::Tensor;
///
/// `y = max(0, x)`
#[derive(Module, Clone, Debug, Default)]
-pub struct ReLU {}
+pub struct Relu {}

-impl ReLU {
+impl Relu {
    /// Create the module.
    pub fn new() -> Self {
        Self {}
@@ -27,14 +27,14 @@ pub struct AdaGradConfig {

/// AdaGrad optimizer
pub struct AdaGrad<B: Backend> {
-    lr_decay: LRDecay,
+    lr_decay: LrDecay,
    weight_decay: Option<WeightDecay<B>>,
}

/// AdaGrad state.
#[derive(Record, Clone, new)]
pub struct AdaGradState<B: Backend, const D: usize> {
-    lr_decay: LRDecayState<B, D>,
+    lr_decay: LrDecayState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdaGrad<B> {
@@ -81,7 +81,7 @@ impl AdaGradConfig {
    /// Returns an optimizer that can be used to optimize a module.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> impl Optimizer<M, B> {
        let optim = AdaGrad {
-            lr_decay: LRDecay {
+            lr_decay: LrDecay {
                lr_decay: self.lr_decay,
                epsilon: self.epsilon,
            },
@@ -98,29 +98,29 @@ impl AdaGradConfig {

/// Learning rate decay state (also includes sum state).
#[derive(Record, new, Clone)]
-pub struct LRDecayState<B: Backend, const D: usize> {
+pub struct LrDecayState<B: Backend, const D: usize> {
    time: usize,
    sum: Tensor<B, D>,
}

-struct LRDecay {
+struct LrDecay {
    lr_decay: f64,
    epsilon: f32,
}

-impl LRDecay {
+impl LrDecay {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        lr: LearningRate,
-        lr_decay_state: Option<LRDecayState<B, D>>,
-    ) -> (Tensor<B, D>, LRDecayState<B, D>) {
+        lr_decay_state: Option<LrDecayState<B, D>>,
+    ) -> (Tensor<B, D>, LrDecayState<B, D>) {
        let state = if let Some(mut state) = lr_decay_state {
            state.sum = state.sum.add(grad.clone().powf_scalar(2.));
            state.time += 1;
            state
        } else {
-            LRDecayState::new(1, grad.clone().powf_scalar(2.))
+            LrDecayState::new(1, grad.clone().powf_scalar(2.))
        };

        let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
@@ -133,7 +133,7 @@ impl LRDecay {
    }
}

-impl<B: Backend, const D: usize> LRDecayState<B, D> {
+impl<B: Backend, const D: usize> LrDecayState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
@@ -278,7 +278,7 @@ mod tests {
    {
        let config = AdaGradConfig::new();
        AdaGrad {
-            lr_decay: LRDecay {
+            lr_decay: LrDecay {
                lr_decay: config.lr_decay,
                epsilon: config.epsilon,
            },
@@ -12,19 +12,19 @@ use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{backend::AutodiffBackend, Tensor};
use burn_tensor::backend::Backend;

-/// Configuration to create the [RMSProp](RMSProp) optimizer.
+/// Configuration to create the [RmsProp](RmsProp) optimizer.
#[derive(Config)]
-pub struct RMSPropConfig {
+pub struct RmsPropConfig {
    /// Smoothing constant.
    #[config(default = 0.99)]
    alpha: f32,
-    /// momentum for RMSProp.
+    /// momentum for RmsProp.
    #[config(default = 0.9)]
    momentum: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
-    /// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
+    /// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance
    #[config(default = false)]
    centered: bool,
    /// [Weight decay](WeightDecayConfig) config.
@@ -33,22 +33,22 @@ pub struct RMSPropConfig {
    grad_clipping: Option<GradientClippingConfig>,
}

-impl RMSPropConfig {
-    /// Initialize RMSProp optimizer.
+impl RmsPropConfig {
+    /// Initialize RmsProp optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
-    ) -> OptimizerAdaptor<RMSProp<B::InnerBackend>, M, B> {
+    ) -> OptimizerAdaptor<RmsProp<B::InnerBackend>, M, B> {
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);

-        let mut optim = OptimizerAdaptor::from(RMSProp {
+        let mut optim = OptimizerAdaptor::from(RmsProp {
            alpha: self.alpha,
            centered: self.centered,
            weight_decay,
-            momentum: RMSPropMomentum {
+            momentum: RmsPropMomentum {
                momentum: self.momentum,
                epsilon: self.epsilon,
            },
@@ -63,18 +63,18 @@ impl RMSPropConfig {
}

/// Optimizer that implements stochastic gradient descent with momentum.
-/// The optimizer can be configured with [RMSPropConfig](RMSPropConfig).
-pub struct RMSProp<B: Backend> {
+/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
+pub struct RmsProp<B: Backend> {
    alpha: f32,
    // epsilon: f32,
    centered: bool,
    // momentum: Option<Momentum<B>>,
-    momentum: RMSPropMomentum,
+    momentum: RmsPropMomentum,
    weight_decay: Option<WeightDecay<B>>,
}

-impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
-    type State<const D: usize> = RMSPropState<B, D>;
+impl<B: Backend> SimpleOptimizer<B> for RmsProp<B> {
+    type State<const D: usize> = RmsPropState<B, D>;

    fn step<const D: usize>(
        &self,
@@ -117,7 +117,7 @@ impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
            .transform(grad, state_centered, state_momentum);

        // transition state
-        let state = RMSPropState::new(state_square_avg, state_centered, state_momentum);
+        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);

        // tensor param transform
        let delta = grad.mul_scalar(lr);
@@ -135,12 +135,12 @@ impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
    }
}

-/// State of [RMSProp](RMSProp)
+/// State of [RmsProp](RmsProp)
#[derive(Record, Clone, new)]
-pub struct RMSPropState<B: Backend, const D: usize> {
+pub struct RmsPropState<B: Backend, const D: usize> {
    square_avg: SquareAvgState<B, D>,
    centered: CenteredState<B, D>,
-    momentum: Option<RMSPropMomentumState<B, D>>,
+    momentum: Option<RmsPropMomentumState<B, D>>,
}

/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
@@ -249,24 +249,24 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
    }
}

-/// [RMSPropMomentum](RMSPropMomentum) is to store config status for optimizer.
-/// (, which is stored in [optimizer](RMSProp) itself and not passed in during `step()` calculation)
-pub struct RMSPropMomentum {
+/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
+/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
+pub struct RmsPropMomentum {
    momentum: f32,
    epsilon: f32,
}

-impl RMSPropMomentum {
-    /// transform [grad](Tensor) and [RMSPropMomentumState] to the next step
+impl RmsPropMomentum {
+    /// transform [grad](Tensor) and [RmsPropMomentumState] to the next step
    fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        centered_state: CenteredState<B, D>,
-        momentum_state: Option<RMSPropMomentumState<B, D>>,
+        momentum_state: Option<RmsPropMomentumState<B, D>>,
    ) -> (
        Tensor<B, D>,
        CenteredState<B, D>,
-        Option<RMSPropMomentumState<B, D>>,
+        Option<RmsPropMomentumState<B, D>>,
    ) {
        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

@@ -278,7 +278,7 @@ impl RMSPropMomentum {
            (
                buf.clone(),
                centered_state,
-                Some(RMSPropMomentumState { buf }),
+                Some(RmsPropMomentumState { buf }),
            )
        } else {
            (grad, centered_state, None)
@@ -286,13 +286,13 @@ impl RMSPropMomentum {
    }
}

-/// [RMSPropMomentumState](RMSPropMomentumState) is to store and pass optimizer step params.
+/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.
#[derive(Record, Clone, new)]
-pub struct RMSPropMomentumState<B: Backend, const D: usize> {
+pub struct RmsPropMomentumState<B: Backend, const D: usize> {
    buf: Tensor<B, D>,
}

-impl<B: Backend, const D: usize> RMSPropMomentumState<B, D> {
+impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
@@ -378,7 +378,7 @@ mod tests {
        )
        .require_grad();

-        let mut optimizer = RMSPropConfig::new()
+        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
@@ -453,7 +453,7 @@ mod tests {
        )
        .require_grad();

-        let mut optimizer = RMSPropConfig::new()
+        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
@@ -529,9 +529,9 @@ mod tests {
    }

    fn create_rmsprop(
-    ) -> OptimizerAdaptor<RMSProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
+    ) -> OptimizerAdaptor<RmsProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
    {
-        RMSPropConfig {
+        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
@@ -49,12 +49,12 @@ pub enum ImporterError {
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
-/// struct MNISTItemRaw {
+/// struct MnistItemRaw {
///     pub image_bytes: Vec<u8>,
///     pub label: usize,
/// }
///
-/// let train_ds:SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
+/// let train_ds:SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
///     .dataset("train")
///     .unwrap();
pub struct HuggingfaceDatasetLoader {
@@ -24,7 +24,7 @@ const HEIGHT: usize = 28;

/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
-pub struct MNISTItem {
+pub struct MnistItem {
    /// Image as a 2D array of floats.
    pub image: [[f32; WIDTH]; HEIGHT],

@@ -33,16 +33,16 @@ pub struct MNISTItem {
}

#[derive(Deserialize, Debug, Clone)]
-struct MNISTItemRaw {
+struct MnistItemRaw {
    pub image_bytes: Vec<u8>,
    pub label: u8,
}

struct BytesToImage;

-impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
+impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
-    fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
+    fn map(&self, item: &MnistItemRaw) -> MnistItem {
        // Ensure the image dimensions are correct.
        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);

@@ -54,25 +54,25 @@ impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
            image_array[y][x] = *pixel as f32;
        }

-        MNISTItem {
+        MnistItem {
            image: image_array,
            label: item.label,
        }
    }
}

-type MappedDataset = MapperDataset<InMemDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
+type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;

/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
/// images per class. There are 60,000 training images and 10,000 test images.
///
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
-pub struct MNISTDataset {
+pub struct MnistDataset {
    dataset: MappedDataset,
}

-impl Dataset<MNISTItem> for MNISTDataset {
-    fn get(&self, index: usize) -> Option<MNISTItem> {
+impl Dataset<MnistItem> for MnistDataset {
+    fn get(&self, index: usize) -> Option<MnistItem> {
        self.dataset.get(index)
    }

@@ -81,7 +81,7 @@ impl Dataset<MNISTItem> for MNISTDataset {
    }
}

-impl MNISTDataset {
+impl MnistDataset {
    /// Creates a new train dataset.
    pub fn train() -> Self {
        Self::new("train")
@@ -94,19 +94,19 @@ impl MNISTDataset {

    fn new(split: &str) -> Self {
        // Download dataset
-        let root = MNISTDataset::download(split);
+        let root = MnistDataset::download(split);

        // MNIST is tiny so we can load it in-memory
        // Train images (u8): 28 * 28 * 60000 = 47.04Mb
        // Test images (u8): 28 * 28 * 10000 = 7.84Mb
-        let images = MNISTDataset::read_images(&root, split);
-        let labels = MNISTDataset::read_labels(&root, split);
+        let images = MnistDataset::read_images(&root, split);
+        let labels = MnistDataset::read_labels(&root, split);

-        // Collect as vector of MNISTItemRaw
+        // Collect as vector of MnistItemRaw
        let items: Vec<_> = images
            .into_iter()
            .zip(labels)
-            .map(|(image_bytes, label)| MNISTItemRaw { image_bytes, label })
+            .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
            .collect();

        let dataset = InMemDataset::new(items);
@@ -132,12 +132,12 @@ impl MNISTDataset {
        // Download split files
        match split {
            "train" => {
-                MNISTDataset::download_file(TRAIN_IMAGES, &split_dir);
-                MNISTDataset::download_file(TRAIN_LABELS, &split_dir);
+                MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
+                MnistDataset::download_file(TRAIN_LABELS, &split_dir);
            }
            "test" => {
-                MNISTDataset::download_file(TEST_IMAGES, &split_dir);
-                MNISTDataset::download_file(TEST_LABELS, &split_dir);
+                MnistDataset::download_file(TEST_IMAGES, &split_dir);
+                MnistDataset::download_file(TEST_LABELS, &split_dir);
            }
            _ => panic!("Invalid split specified {}", split),
        };
@@ -1,6 +1,6 @@
use burn::{
    module::Module,
-    nn::{Linear, LinearConfig, ReLU},
+    nn::{Linear, LinearConfig, Relu},
    tensor::{backend::Backend, Tensor},
};

@@ -8,7 +8,7 @@ use burn::{
pub struct Net<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
-    relu: ReLU,
+    relu: Relu,
}

impl<B: Backend> Net<B> {
@@ -16,7 +16,7 @@ impl<B: Backend> Net<B> {
    pub fn new_with(record: NetRecord<B>) -> Self {
        let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);
        let fc2 = LinearConfig::new(3, 4).init_with(record.fc2);
-        let relu = ReLU::default();
+        let relu = Relu::default();

        Self { fc1, fc2, relu }
    }
@@ -9,7 +9,7 @@ use crate::onnx::{
    proto_conversion::convert_node_proto,
};

-use super::ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor};
+use super::ir::{ArgType, Argument, Node, NodeType, OnnxGraph, Tensor};
use super::protos::{ModelProto, TensorProto};
use super::{dim_inference::dim_inference, protos::ValueInfoProto};

@@ -33,14 +33,14 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
///
/// # Returns
///
-/// * `ONNXGraph` - The graph representation of the onnx file
+/// * `OnnxGraph` - The graph representation of the onnx file
///
/// # Panics
///
/// * If the file cannot be opened
/// * If the file cannot be parsed
/// * If the nodes are not topologically sorted
-pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
+pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
    log::info!("Parsing ONNX file: {}", onnx_path.display());

    // Open the file
@@ -118,7 +118,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {

    log::info!("Finished parsing ONNX file: {}", onnx_path.display());

-    ONNXGraph {
+    OnnxGraph {
        nodes,
        inputs,
        outputs,
@@ -129,7 +129,7 @@ pub enum Data {

/// ONNX graph representation
#[derive(Debug, Clone)]
-pub struct ONNXGraph {
+pub struct OnnxGraph {
    /// The nodes of the graph.
    pub nodes: Vec<Node>,
@@ -11,4 +11,4 @@ mod to_burn;
pub use to_burn::*;

pub use from_onnx::parse_onnx;
-pub use ir::ONNXGraph;
+pub use ir::OnnxGraph;
@@ -45,7 +45,7 @@ use crate::{

use super::{
    from_onnx::parse_onnx,
-    ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph},
+    ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
    op_configuration::{
        avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
        softmax_config,
@@ -218,7 +218,7 @@ impl ModelGen {
    }
}

-impl ONNXGraph {
+impl OnnxGraph {
    /// Converts ONNX graph to Burn graph.
    pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
        let mut graph = BurnGraph::<PS>::default();
@@ -28,7 +28,7 @@ pub struct MlpConfig {
pub struct Mlp<B: Backend> {
    linears: Vec<nn::Linear<B>>,
    dropout: nn::Dropout,
-    activation: nn::ReLU,
+    activation: nn::Relu,
}

impl<B: Backend> Mlp<B> {
@@ -43,7 +43,7 @@ impl<B: Backend> Mlp<B> {
        Self {
            linears,
            dropout: nn::DropoutConfig::new(0.3).init(),
-            activation: nn::ReLU::new(),
+            activation: nn::Relu::new(),
        }
    }
@@ -3,11 +3,11 @@ use crate::metric::{Metric, MetricEntry};
use nvml_wrapper::Nvml;

/// Track basic cuda infos.
-pub struct CUDAMetric {
+pub struct CudaMetric {
    nvml: Option<Nvml>,
}

-impl CUDAMetric {
+impl CudaMetric {
    /// Creates a new metric for CUDA.
    pub fn new() -> Self {
        Self {
@@ -19,7 +19,7 @@ impl CUDAMetric {
    }
}

-impl Default for CUDAMetric {
+impl Default for CudaMetric {
    fn default() -> Self {
        Self::new()
    }
@@ -29,7 +29,7 @@ impl<T> Adaptor<()> for T {
    fn adapt(&self) {}
}

-impl Metric for CUDAMetric {
+impl Metric for CudaMetric {
    const NAME: &'static str = "CUDA Stats";

    type Input = ();
@@ -61,7 +61,7 @@
//! - `audio`: Enables audio datasets (SpeechCommandsDataset)
//! - `sqlite`: Stores datasets in SQLite database
//! - `sqlite_bundled`: Use bundled version of SQLite
-//! - `vision`: Enables vision datasets (MNISTDataset)
+//! - `vision`: Enables vision datasets (MnistDataset)
//! - Backends
//! - `wgpu`: Makes available the WGPU backend
//! - `candle`: Makes available the Candle backend
@@ -3,7 +3,7 @@ use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{MaxPool2d, MaxPool2dConfig},
-        Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, ReLU,
+        Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
    },
    tensor::{backend::Backend, Device, Tensor},
};
@@ -23,8 +23,8 @@ use burn::{
// │ maxpool │
// └────────────────────┘
#[derive(Module, Debug)]
-pub struct CNN<B: Backend> {
-    activation: ReLU,
+pub struct Cnn<B: Backend> {
+    activation: Relu,
    dropout: Dropout,
    pool: MaxPool2d,
    conv1: Conv2d<B>,
@@ -37,7 +37,7 @@ pub struct CNN<B: Backend> {
    fc2: Linear<B>,
}

-impl<B: Backend> CNN<B> {
+impl<B: Backend> Cnn<B> {
    pub fn new(num_classes: usize, device: &Device<B>) -> Self {
        let conv1 = Conv2dConfig::new([3, 32], [3, 3])
            .with_padding(PaddingConfig2d::Same)
@@ -68,7 +68,7 @@ impl<B: Backend> CNN<B> {
        let dropout = DropoutConfig::new(0.3).init();

        Self {
-            activation: ReLU::new(),
+            activation: Relu::new(),
            dropout,
            pool,
            conv1,
@@ -3,7 +3,7 @@ use std::time::Instant;
use crate::{
    data::{ClassificationBatch, ClassificationBatcher},
    dataset::CIFAR10Loader,
-    model::CNN,
+    model::Cnn,
};
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
use burn::train::{
@@ -26,7 +26,7 @@ use burn::{
const NUM_CLASSES: u8 = 10;
const ARTIFACT_DIR: &str = "/tmp/custom-image-dataset";

-impl<B: Backend> CNN<B> {
+impl<B: Backend> Cnn<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 4>,
@@ -41,7 +41,7 @@ impl<B: Backend> CNN<B> {
    }
}

-impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<B>> for CNN<B> {
+impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<B>> for Cnn<B> {
    fn step(&self, batch: ClassificationBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

@@ -49,7 +49,7 @@ impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<
    }
}

-impl<B: Backend> ValidStep<ClassificationBatch<B>, ClassificationOutput<B>> for CNN<B> {
+impl<B: Backend> ValidStep<ClassificationBatch<B>, ClassificationOutput<B>> for Cnn<B> {
    fn step(&self, batch: ClassificationBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
@@ -104,7 +104,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .build(
-            CNN::new(NUM_CLASSES.into(), &device),
+            Cnn::new(NUM_CLASSES.into(), &device),
            config.optimizer.init(),
            config.learning_rate,
        );
@@ -1,11 +1,11 @@
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
use burn::train::LearnerBuilder;
use burn::{
    config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
    tensor::backend::AutodiffBackend,
};
-use guide::{data::MNISTBatcher, model::ModelConfig};
+use guide::{data::MnistBatcher, model::ModelConfig};

#[derive(Config)]
pub struct MnistTrainingConfig {
@@ -52,21 +52,21 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
    let optim = config.optimizer.init();

    // Create the batcher.
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    // Create the dataloaders.
    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

    // artifact dir does not need to be provided when log_to_file is false
    let builder = LearnerBuilder::new("")
@@ -1,6 +1,6 @@
use std::marker::PhantomData;

-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::{
    config::Config,
    data::dataloader::DataLoaderBuilder,
@@ -13,7 +13,7 @@ use burn::{
    },
};
use guide::{
-    data::{MNISTBatch, MNISTBatcher},
+    data::{MnistBatch, MnistBatcher},
    model::{Model, ModelConfig},
};

@@ -46,21 +46,21 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
    let mut optim = config.optimizer.init();

    // Create the batcher.
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    // Create the dataloaders.
    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

    // Iterate over our training and validation loop for X epochs.
    for epoch in 1..config.num_epochs + 1 {
@@ -145,7 +145,7 @@ where
    B: AutodiffBackend,
    O: Optimizer<Model<B>, B>,
{
-    pub fn step1(&mut self, _batch: MNISTBatch<B>) {
+    pub fn step1(&mut self, _batch: MnistBatch<B>) {
        //
    }
}
@@ -156,14 +156,14 @@ where
    B: AutodiffBackend,
    O: Optimizer<Model<B>, B>,
{
-    pub fn step2(&mut self, _batch: MNISTBatch<B>) {
+    pub fn step2(&mut self, _batch: MnistBatch<B>) {
        //
    }
}

#[allow(dead_code)]
impl<M, O> Learner2<M, O> {
-    pub fn step3<B: AutodiffBackend>(&mut self, _batch: MNISTBatch<B>)
+    pub fn step3<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
    where
        B: AutodiffBackend,
        M: AutodiffModule<B>,
@@ -18,7 +18,7 @@ fn main() {
    guide::inference::infer::<MyBackend>(
        artifact_dir,
        device,
-        burn::data::dataset::vision::MNISTDataset::test()
+        burn::data::dataset::vision::MnistDataset::test()
            .get(42)
            .unwrap(),
    );
@@ -1,26 +1,26 @@
use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};

-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

#[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| Data::<f32, 2>::from(item.image))
@@ -40,6 +40,6 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
        let images = Tensor::cat(images, 0);
        let targets = Tensor::cat(targets, 0);

-        MNISTBatch { images, targets }
+        MnistBatch { images, targets }
    }
}
@@ -1,5 +1,5 @@
-use crate::{data::MNISTBatcher, training::TrainingConfig};
-use burn::data::dataset::vision::MNISTItem;
+use crate::{data::MnistBatcher, training::TrainingConfig};
+use burn::data::dataset::vision::MnistItem;
use burn::{
    config::Config,
    data::dataloader::batcher::Batcher,
@@ -7,7 +7,7 @@ use burn::{
    tensor::backend::Backend,
};

-pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist for the model");
    let record = CompactRecorder::new()
@@ -17,7 +17,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
    let model = config.model.init_with::<B>(record);

    let label = item.label;
-    let batcher = MNISTBatcher::new(device);
+    let batcher = MnistBatcher::new(device);
    let batch = batcher.batch(vec![item]);
    let output = model.forward(batch.images);
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
@@ -4,7 +4,7 @@ use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
-        Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
+        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    tensor::{backend::Backend, Tensor},
};
@@ -17,7 +17,7 @@ pub struct Model<B: Backend> {
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
}

#[derive(Config, Debug)]
@@ -35,7 +35,7 @@ impl ModelConfig {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
@@ -47,7 +47,7 @@ impl ModelConfig {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes)
                .init_with(record.linear2),
@@ -1,8 +1,8 @@
use crate::{
-    data::{MNISTBatch, MNISTBatcher},
+    data::{MnistBatch, MnistBatcher},
    model::{Model, ModelConfig},
};
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::train::{
    metric::{AccuracyMetric, LossMetric},
    ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
@@ -36,16 +36,16 @@ impl<B: Backend> Model<B> {
    }
}

-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}
@@ -74,20 +74,20 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev

    B::seed(config.seed);

-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
@@ -1,26 +1,26 @@
use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};

-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

#[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| Data::<f32, 2>::from(item.image))
@@ -45,6 +45,6 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
        let images = Tensor::cat(images, 0);
        let targets = Tensor::cat(targets, 0);

-        MNISTBatch { images, targets }
+        MnistBatch { images, targets }
    }
}
@@ -1,4 +1,4 @@
-use crate::data::MNISTBatch;
+use crate::data::MnistBatch;
use burn::{
    module::Module,
    nn::{self, loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},
@@ -73,7 +73,7 @@ impl<B: Backend> Model<B> {
        self.fc2.forward(x)
    }

-    pub fn forward_classification(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
+    pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
        let targets = item.targets;
        let output = self.forward(item.images);
        let loss = CrossEntropyLossConfig::new()
@@ -117,16 +117,16 @@ impl<B: Backend> ConvBlock<B> {
    }
}

-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, item: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, item: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(item);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(item)
    }
}
@@ -1,4 +1,4 @@
-use crate::data::MNISTBatcher;
+use crate::data::MnistBatcher;
use crate::model::Model;

use burn::module::Module;
@@ -10,7 +10,7 @@ use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
use burn::{
    config::Config,
-    data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
+    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
    tensor::backend::AutodiffBackend,
    train::{
        metric::{AccuracyMetric, LossMetric},
@@ -44,19 +44,19 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
    B::seed(config.seed);

    // Data
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());
    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

    // Model
    let learner = LearnerBuilder::new(ARTIFACT_DIR)
@@ -3,7 +3,7 @@ use std::env::args;
use burn::backend::ndarray::NdArray;
use burn::tensor::Tensor;

-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::data::dataset::Dataset;

use onnx_inference::mnist::Model;
@@ -34,7 +34,7 @@ fn main() {
    let model: Model<Backend> = Model::default();

    // Load the MNIST dataset and get an item
-    let dataset = MNISTDataset::test();
+    let dataset = MnistDataset::test();
    let item = dataset.get(image_index).unwrap();

    // Create a tensor from the image data
@@ -5,7 +5,7 @@ use burn::backend::ndarray::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::tensor::Tensor;

-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
use burn::data::dataset::Dataset;

use model::Model;
@@ -42,7 +42,7 @@ fn main() {
    let model: Model<Backend> = Model::new_with(record);

    // Load the MNIST dataset and get an item
-    let dataset = MNISTDataset::test();
+    let dataset = MnistDataset::test();
    let item = dataset.get(image_index).unwrap();

    // Create a tensor from the image data
@@ -1,10 +1,10 @@
use crate::dataset::DiabetesBatch;
use burn::config::Config;
use burn::nn::loss::Reduction::Mean;
-use burn::nn::ReLU;
+use burn::nn::Relu;
use burn::{
    module::Module,
-    nn::{loss::MSELoss, Linear, LinearConfig},
+    nn::{loss::MseLoss, Linear, LinearConfig},
    tensor::{
        backend::{AutodiffBackend, Backend},
        Tensor,
@@ -16,7 +16,7 @@ use burn::{
pub struct RegressionModel<B: Backend> {
    input_layer: Linear<B>,
    output_layer: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
}

#[derive(Config)]
@@ -39,7 +39,7 @@ impl RegressionModelConfig {
        RegressionModel {
            input_layer,
            output_layer,
-            activation: ReLU::new(),
+            activation: Relu::new(),
        }
    }
}
@@ -56,7 +56,7 @@ impl<B: Backend> RegressionModel<B> {
        let targets: Tensor<B, 2> = item.targets.unsqueeze();
        let output: Tensor<B, 2> = self.forward(item.inputs);

-        let loss = MSELoss::new().forward(output.clone(), targets.clone(), Mean);
+        let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);

        RegressionOutput {
            loss,
@@ -19,7 +19,7 @@ use burn::{
    record::{CompactRecorder, Recorder},
    tensor::backend::AutodiffBackend,
    train::{
-        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
        LearnerBuilder,
    },
};
@@ -91,8 +91,8 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(

    // Initialize learner
    let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train(CUDAMetric::new())
-        .metric_valid(CUDAMetric::new())
+        .metric_train(CudaMetric::new())
+        .metric_valid(CudaMetric::new())
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
@@ -13,7 +13,7 @@ use burn::{
    record::{CompactRecorder, DefaultRecorder, Recorder},
    tensor::backend::AutodiffBackend,
    train::{
-        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
        LearnerBuilder,
    },
};
@@ -68,8 +68,8 @@ pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
    .init();

    let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train(CUDAMetric::new())
-        .metric_valid(CUDAMetric::new())
+        .metric_train(CudaMetric::new())
+        .metric_valid(CudaMetric::new())
        .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
        .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
        .metric_train(LossMetric::new())