mirror of https://github.com/tracel-ai/burn.git
Make all struct CamelCase (#1316)
parent: dfb739c89a
commit: 44266d5fd4
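This commit applies one naming convention across the codebase: acronyms inside type names keep only their first letter capitalized (`MNIST` becomes `Mnist`, `ReLU` becomes `Relu`, `MSELoss` becomes `MseLoss`, `RMSProp` becomes `RmsProp`, `CUDAMetric` becomes `CudaMetric`, `ONNXGraph` becomes `OnnxGraph`, and so on). As a minimal sketch of what the rename means for downstream code (the import paths are the ones appearing in the hunks below; the surrounding function is only illustrative):

```rust
use burn::data::dataset::vision::MnistDataset; // formerly MNISTDataset
use burn::nn::Relu;                            // formerly ReLU

fn example() {
    // Construction is unchanged; only the type names differ.
    let activation = Relu::new();
    let dataset = MnistDataset::train();
    let _ = (activation, dataset);
}
```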
@@ -194,7 +194,7 @@ impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
 ```
 
 Subsequently, we'll go into implementing our custom backend trait for the WGPU backend.
 Note that we won't go into supporting the `fusion` feature flag in this tutorial, so
 we implement the trait for the raw `WgpuBackend` type.
 
 ```rust, ignore
@@ -16,15 +16,15 @@ at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/exa
 
 ```rust , ignore
 use burn::{
-data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
 tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
 };
 
-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
 device: B::Device,
 }
 
-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
 pub fn new(device: B::Device) -> Self {
 Self { device }
 }
@@ -42,13 +42,13 @@ Next, we need to actually implement the batching logic.
 
 ```rust , ignore
 #[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
 pub images: Tensor<B, 3>,
 pub targets: Tensor<B, 1, Int>,
 }
 
-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
 let images = items
 .iter()
 .map(|item| Data::<f32, 2>::from(item.image))
@@ -71,7 +71,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
 let images = Tensor::cat(images, 0).to_device(&self.device);
 let targets = Tensor::cat(targets, 0).to_device(&self.device);
 
-MNISTBatch { images, targets }
+MnistBatch { images, targets }
 }
 }
 ```
@@ -81,7 +81,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
 
 The iterator pattern allows you to perform some tasks on a sequence of items in turn.
 
-In this example, an iterator is created over the `MNISTItem`s in the vector `items` by calling the
+In this example, an iterator is created over the `MnistItem`s in the vector `items` by calling the
 `iter` method.
 
 _Iterator adaptors_ are methods defined on the `Iterator` trait that produce different iterators by
@@ -100,7 +100,7 @@ If we go back to the example, we can break down and comment the expression used
 images.
 
 ```rust, ignore
-let images = items // take items Vec<MNISTItem>
+let images = items // take items Vec<MnistItem>
 .iter() // create an iterator over it
 .map(|item| Data::<f32, 2>::from(item.image)) // for each item, convert the image to float32 data struct
 .map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device
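The `iter`/`map` chain above is plain Rust and works the same on any `Vec`, independent of Burn. A minimal, self-contained illustration of the pattern (not part of the diff):

```rust
fn main() {
    let items = vec![1.0_f32, 2.0, 3.0];
    // `iter` creates an iterator, `map` lazily transforms each element,
    // and `collect` consumes the iterator into a new Vec.
    let doubled: Vec<f32> = items.iter().map(|x| x * 2.0).collect();
    assert_eq!(doubled, vec![2.0, 4.0, 6.0]);
}
```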
@@ -115,8 +115,8 @@ Book.
 
 </details><br>
 
-In the previous example, we implement the `Batcher` trait with a list of `MNISTItem` as input and a
-single `MNISTBatch` as output. The batch contains the images in the form of a 3D tensor, along with
+In the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a
+single `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with
 a targets tensor that contains the indexes of the correct digit class. The first step is to parse
 the image array into a `Data` struct. Burn provides the `Data` struct to encapsulate tensor storage
 information without being specific for a backend. When creating a tensor from data, we often need to
@@ -16,7 +16,7 @@ impl ModelConfig {
 conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
 conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
 pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-activation: ReLU::new(),
+activation: Relu::new(),
 linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
 linear2: LinearConfig::new(self.hidden_size, self.num_classes)
 .init_with(record.linear2),
@@ -33,7 +33,7 @@ manually. Everything is validated when loading the model with the record.
 Now let's create a simple `infer` method in a new file `src/inference.rs` which we will use to load our trained model.
 
 ```rust , ignore
-pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
 let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
 .expect("Config should exist for the model");
 let record = CompactRecorder::new()
@@ -43,7 +43,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
 let model = config.model.init_with::<B>(record);
 
 let label = item.label;
-let batcher = MNISTBatcher::new(device);
+let batcher = MnistBatcher::new(device);
 let batch = batcher.batch(vec![item]);
 let output = model.forward(batch.images);
 let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
@@ -56,6 +56,6 @@ The first step is to load the configuration of the training to fetch the correct
 configuration. Then we can fetch the record using the same recorder as we used during training.
 Finally we can init the model with the configuration and the record before sending it to the wanted
 device for inference. For simplicity we can use the same batcher used during the training to pass
-from a MNISTItem to a tensor.
+from a MnistItem to a tensor.
 
 By running the infer function, you should see the predictions of your model!
@@ -35,7 +35,7 @@ use burn::{
 nn::{
 conv::{Conv2d, Conv2dConfig},
 pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
-Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
+Dropout, DropoutConfig, Linear, LinearConfig, Relu,
 },
 tensor::{backend::Backend, Tensor},
 };
@@ -48,7 +48,7 @@ pub struct Model<B: Backend> {
 dropout: Dropout,
 linear1: Linear<B>,
 linear2: Linear<B>,
-activation: ReLU,
+activation: Relu,
 }
 ```
 
@@ -98,7 +98,7 @@ There are two major things going on in this code sample.
 pub struct MyCustomModule<B: Backend> {
 linear1: Linear<B>,
 linear2: Linear<B>,
-activation: ReLU,
+activation: Relu,
 }
 ```
 
@@ -178,7 +178,7 @@ impl ModelConfig {
 conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
 conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
 pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-activation: ReLU::new(),
+activation: Relu::new(),
 linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
 linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
 dropout: DropoutConfig::new(self.dropout).init(),
@@ -43,23 +43,23 @@ Moving forward, we will proceed with the implementation of both the training and
 for our model.
 
 ```rust , ignore
-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
 let item = self.forward_classification(batch.images, batch.targets);
 
 TrainOutput::new(self, item.loss.backward(), item)
 }
 }
 
-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
 self.forward_classification(batch.images, batch.targets)
 }
 }
 ```
 
 Here we define the input and output types as generic arguments in the `TrainStep` and `ValidStep`.
-We will call them `MNISTBatch` and `ClassificationOutput`. In the training step, the computation of
+We will call them `MnistBatch` and `ClassificationOutput`. In the training step, the computation of
 gradients is straightforward, necessitating a simple invocation of `backward()` on the loss. Note
 that contrary to PyTorch, gradients are not stored alongside each tensor parameter, but are rather
 returned by the backward pass, as such: `let gradients = loss.backward();`. The gradient of a
@@ -81,8 +81,8 @@ which is generic over the `Backend` trait as has been covered before. These trai
 `burn::train` and define a common `step` method that should be implemented for all structs. Since
 the trait is generic over the input and output types, the trait implementation must specify the
 concrete types used. This is where the additional type constraints appear
-`<MNISTBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the
-batch is `MNISTBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`
+`<MnistBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the
+batch is `MnistBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`
 method signature matches the concrete input and output types.
 
 For more details specific to constraints on generic types when defining methods, take a look at
@@ -118,20 +118,20 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
 
 B::seed(config.seed);
 
-let batcher_train = MNISTBatcher::<B>::new(device.clone());
-let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+let batcher_train = MnistBatcher::<B>::new(device.clone());
+let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
 
 let dataloader_train = DataLoaderBuilder::new(batcher_train)
 .batch_size(config.batch_size)
 .shuffle(config.seed)
 .num_workers(config.num_workers)
-.build(MNISTDataset::train());
+.build(MnistDataset::train());
 
 let dataloader_test = DataLoaderBuilder::new(batcher_valid)
 .batch_size(config.batch_size)
 .shuffle(config.seed)
 .num_workers(config.num_workers)
-.build(MNISTDataset::test());
+.build(MnistDataset::test());
 
 let learner = LearnerBuilder::new(artifact_dir)
 .metric_train_numeric(AccuracyMetric::new())
@@ -160,4 +160,4 @@ Burn comes with built-in modules that you can use to build your own modules.
 | Burn API | PyTorch Equivalent |
 | ------------------ | --------------------- |
 | `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
-| `MSELoss` | `nn.MSELoss` |
+| `MseLoss` | `nn.MSELoss` |
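For reference, the renamed loss type keeps the same call pattern; a fragment based on the test code further down in this diff (it assumes `logits` and `targets` are tensors of the same shape and that `MseLoss` and `Reduction` are in scope):

```rust
let mse = MseLoss::new();
// Mean-reduced mean squared error between predictions and targets.
let loss = mse.forward(logits, targets, Reduction::Auto);
```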
@@ -40,21 +40,21 @@ pub fn run<B: AutodiffBackend>(device: &B::Device) {
 let mut optim = config.optimizer.init();
 
 // Create the batcher.
-let batcher_train = MNISTBatcher::<B>::new(device.clone());
-let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+let batcher_train = MnistBatcher::<B>::new(device.clone());
+let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());
 
 // Create the dataloaders.
 let dataloader_train = DataLoaderBuilder::new(batcher_train)
 .batch_size(config.batch_size)
 .shuffle(config.seed)
 .num_workers(config.num_workers)
-.build(MNISTDataset::train());
+.build(MnistDataset::train());
 
 let dataloader_test = DataLoaderBuilder::new(batcher_valid)
 .batch_size(config.batch_size)
 .shuffle(config.seed)
 .num_workers(config.num_workers)
-.build(MNISTDataset::test());
+.build(MnistDataset::test());
 
 ...
 }
@@ -140,7 +140,7 @@ Note that after each epoch, we include a validation loop to assess our model's p
 previously unseen data. To disable gradient tracking during this validation step, we can invoke
 `model.valid()`, which provides a model on the inner backend without autodiff capabilities. It's
 important to emphasize that we've declared our validation batcher to be on the inner backend,
-specifically `MNISTBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation
+specifically `MnistBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation
 error.
 
 You can find the code above available as an
@@ -195,7 +195,7 @@ where
 M: AutodiffModule<B>,
 O: Optimizer<M, B>,
 {
-pub fn step(&mut self, _batch: MNISTBatch<B>) {
+pub fn step(&mut self, _batch: MnistBatch<B>) {
 //
 }
 }
@@ -214,7 +214,7 @@ the backend and add your trait constraint within its definition:
 ```rust, ignore
 #[allow(dead_code)]
 impl<M, O> Learner2<M, O> {
-pub fn step<B: AutodiffBackend>(&mut self, _batch: MNISTBatch<B>)
+pub fn step<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
 where
 B: AutodiffBackend,
 M: AutodiffModule<B>,
@@ -44,7 +44,7 @@ model definition as a simple example.
 pub struct Model<B: Backend> {
 linear_in: Linear<B>,
 linear_out: Linear<B>,
-activation: ReLU,
+activation: Relu,
 }
 ```
 
@@ -59,7 +59,7 @@ impl<B: Backend> Model<B> {
 Model {
 linear_in: LinearConfig::new(10, 64).init_with(record.linear_in),
 linear_out: LinearConfig::new(64, 2).init_with(record.linear_out),
-activation: ReLU::new(),
+activation: Relu::new(),
 }
 }
 
@@ -70,7 +70,7 @@ impl<B: Backend> Model<B> {
 Model {
 linear_in: l1,
 linear_out: l2,
-activation: ReLU::new(),
+activation: Relu::new(),
 }
 }
 }
@@ -5,17 +5,17 @@ use burn_tensor::{backend::Backend, Tensor};
 
 /// Calculate the mean squared error loss from the input logits and the targets.
 #[derive(Clone, Debug)]
-pub struct MSELoss<B: Backend> {
+pub struct MseLoss<B: Backend> {
 backend: PhantomData<B>,
 }
 
-impl<B: Backend> Default for MSELoss<B> {
+impl<B: Backend> Default for MseLoss<B> {
 fn default() -> Self {
 Self::new()
 }
 }
 
-impl<B: Backend> MSELoss<B> {
+impl<B: Backend> MseLoss<B> {
 /// Create the criterion.
 pub fn new() -> Self {
 Self {
@@ -67,7 +67,7 @@ mod tests {
 let targets =
 Tensor::<TestBackend, 2>::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]]), &device);
 
-let mse = MSELoss::new();
+let mse = MseLoss::new();
 let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
 let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
 let loss_sum = mse.forward(logits, targets, Reduction::Sum);
@@ -8,9 +8,9 @@ use crate::tensor::Tensor;
 ///
 /// `y = max(0, x)`
 #[derive(Module, Clone, Debug, Default)]
-pub struct ReLU {}
+pub struct Relu {}
 
-impl ReLU {
+impl Relu {
 /// Create the module.
 pub fn new() -> Self {
 Self {}
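As the doc comment above states, the module computes `y = max(0, x)` element-wise, so an input of `[-1.5, 0.0, 2.0]` maps to `[0.0, 0.0, 2.0]`; the rename from `ReLU` to `Relu` does not change this behavior.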
@@ -27,14 +27,14 @@ pub struct AdaGradConfig {
 
 /// AdaGrad optimizer
 pub struct AdaGrad<B: Backend> {
-lr_decay: LRDecay,
+lr_decay: LrDecay,
 weight_decay: Option<WeightDecay<B>>,
 }
 
 /// AdaGrad state.
 #[derive(Record, Clone, new)]
 pub struct AdaGradState<B: Backend, const D: usize> {
-lr_decay: LRDecayState<B, D>,
+lr_decay: LrDecayState<B, D>,
 }
 
 impl<B: Backend> SimpleOptimizer<B> for AdaGrad<B> {
@@ -81,7 +81,7 @@ impl AdaGradConfig {
 /// Returns an optimizer that can be used to optimize a module.
 pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> impl Optimizer<M, B> {
 let optim = AdaGrad {
-lr_decay: LRDecay {
+lr_decay: LrDecay {
 lr_decay: self.lr_decay,
 epsilon: self.epsilon,
 },
@@ -98,29 +98,29 @@ impl AdaGradConfig {
 
 /// Learning rate decay state (also includes sum state).
 #[derive(Record, new, Clone)]
-pub struct LRDecayState<B: Backend, const D: usize> {
+pub struct LrDecayState<B: Backend, const D: usize> {
 time: usize,
 sum: Tensor<B, D>,
 }
 
-struct LRDecay {
+struct LrDecay {
 lr_decay: f64,
 epsilon: f32,
 }
 
-impl LRDecay {
+impl LrDecay {
 pub fn transform<B: Backend, const D: usize>(
 &self,
 grad: Tensor<B, D>,
 lr: LearningRate,
-lr_decay_state: Option<LRDecayState<B, D>>,
-) -> (Tensor<B, D>, LRDecayState<B, D>) {
+lr_decay_state: Option<LrDecayState<B, D>>,
+) -> (Tensor<B, D>, LrDecayState<B, D>) {
 let state = if let Some(mut state) = lr_decay_state {
 state.sum = state.sum.add(grad.clone().powf_scalar(2.));
 state.time += 1;
 state
 } else {
-LRDecayState::new(1, grad.clone().powf_scalar(2.))
+LrDecayState::new(1, grad.clone().powf_scalar(2.))
 };
 
 let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
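The last line of this hunk is the learning-rate decay rule: at step `t` the effective rate is `lr / (1 + (t - 1) * lr_decay)`. For example, with `lr = 0.1` and `lr_decay = 0.5`, the third step uses `0.1 / (1 + 2 * 0.5) = 0.05`.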
@@ -133,7 +133,7 @@ impl LRDecay {
 }
 }
 
-impl<B: Backend, const D: usize> LRDecayState<B, D> {
+impl<B: Backend, const D: usize> LrDecayState<B, D> {
 /// Move state to device.
 ///
 /// # Arguments
@@ -278,7 +278,7 @@ mod tests {
 {
 let config = AdaGradConfig::new();
 AdaGrad {
-lr_decay: LRDecay {
+lr_decay: LrDecay {
 lr_decay: config.lr_decay,
 epsilon: config.epsilon,
 },
@@ -12,19 +12,19 @@ use crate::optim::adaptor::OptimizerAdaptor;
 use crate::tensor::{backend::AutodiffBackend, Tensor};
 use burn_tensor::backend::Backend;
 
-/// Configuration to create the [RMSProp](RMSProp) optimizer.
+/// Configuration to create the [RmsProp](RmsProp) optimizer.
 #[derive(Config)]
-pub struct RMSPropConfig {
+pub struct RmsPropConfig {
 /// Smoothing constant.
 #[config(default = 0.99)]
 alpha: f32,
-/// momentum for RMSProp.
+/// momentum for RmsProp.
 #[config(default = 0.9)]
 momentum: f32,
 /// A value required for numerical stability.
 #[config(default = 1e-5)]
 epsilon: f32,
-/// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
+/// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance
 #[config(default = false)]
 centered: bool,
 /// [Weight decay](WeightDecayConfig) config.
@@ -33,22 +33,22 @@ pub struct RMSPropConfig {
 grad_clipping: Option<GradientClippingConfig>,
 }
 
-impl RMSPropConfig {
-/// Initialize RMSProp optimizer.
+impl RmsPropConfig {
+/// Initialize RmsProp optimizer.
 ///
 /// # Returns
 ///
 /// Returns an optimizer that can be used to optimize a module.
 pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
 &self,
-) -> OptimizerAdaptor<RMSProp<B::InnerBackend>, M, B> {
+) -> OptimizerAdaptor<RmsProp<B::InnerBackend>, M, B> {
 let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
 
-let mut optim = OptimizerAdaptor::from(RMSProp {
+let mut optim = OptimizerAdaptor::from(RmsProp {
 alpha: self.alpha,
 centered: self.centered,
 weight_decay,
-momentum: RMSPropMomentum {
+momentum: RmsPropMomentum {
 momentum: self.momentum,
 epsilon: self.epsilon,
 },
@@ -63,18 +63,18 @@ impl RMSPropConfig {
 }
 
 /// Optimizer that implements stochastic gradient descent with momentum.
-/// The optimizer can be configured with [RMSPropConfig](RMSPropConfig).
-pub struct RMSProp<B: Backend> {
+/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
+pub struct RmsProp<B: Backend> {
 alpha: f32,
 // epsilon: f32,
 centered: bool,
 // momentum: Option<Momentum<B>>,
-momentum: RMSPropMomentum,
+momentum: RmsPropMomentum,
 weight_decay: Option<WeightDecay<B>>,
 }
 
-impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
-type State<const D: usize> = RMSPropState<B, D>;
+impl<B: Backend> SimpleOptimizer<B> for RmsProp<B> {
+type State<const D: usize> = RmsPropState<B, D>;
 
 fn step<const D: usize>(
 &self,
@@ -117,7 +117,7 @@ impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
 .transform(grad, state_centered, state_momentum);
 
 // transition state
-let state = RMSPropState::new(state_square_avg, state_centered, state_momentum);
+let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);
 
 // tensor param transform
 let delta = grad.mul_scalar(lr);
@@ -135,12 +135,12 @@ impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
 }
 }
 
-/// State of [RMSProp](RMSProp)
+/// State of [RmsProp](RmsProp)
 #[derive(Record, Clone, new)]
-pub struct RMSPropState<B: Backend, const D: usize> {
+pub struct RmsPropState<B: Backend, const D: usize> {
 square_avg: SquareAvgState<B, D>,
 centered: CenteredState<B, D>,
-momentum: Option<RMSPropMomentumState<B, D>>,
+momentum: Option<RmsPropMomentumState<B, D>>,
 }
 
 /// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
@@ -249,24 +249,24 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
 }
 }
 
-/// [RMSPropMomentum](RMSPropMomentum) is to store config status for optimizer.
-/// (, which is stored in [optimizer](RMSProp) itself and not passed in during `step()` calculation)
-pub struct RMSPropMomentum {
+/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
+/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
+pub struct RmsPropMomentum {
 momentum: f32,
 epsilon: f32,
 }
 
-impl RMSPropMomentum {
-/// transform [grad](Tensor) and [RMSPropMomentumState] to the next step
+impl RmsPropMomentum {
+/// transform [grad](Tensor) and [RmsPropMomentumState] to the next step
 fn transform<B: Backend, const D: usize>(
 &self,
 grad: Tensor<B, D>,
 centered_state: CenteredState<B, D>,
-momentum_state: Option<RMSPropMomentumState<B, D>>,
+momentum_state: Option<RmsPropMomentumState<B, D>>,
 ) -> (
 Tensor<B, D>,
 CenteredState<B, D>,
-Option<RMSPropMomentumState<B, D>>,
+Option<RmsPropMomentumState<B, D>>,
 ) {
 let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
 
@@ -278,7 +278,7 @@ impl RMSPropMomentum {
 (
 buf.clone(),
 centered_state,
-Some(RMSPropMomentumState { buf }),
+Some(RmsPropMomentumState { buf }),
 )
 } else {
 (grad, centered_state, None)
@@ -286,13 +286,13 @@ impl RMSPropMomentum {
 }
 }
 
-/// [RMSPropMomentumState](RMSPropMomentumState) is to store and pass optimizer step params.
+/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.
 #[derive(Record, Clone, new)]
-pub struct RMSPropMomentumState<B: Backend, const D: usize> {
+pub struct RmsPropMomentumState<B: Backend, const D: usize> {
 buf: Tensor<B, D>,
 }
 
-impl<B: Backend, const D: usize> RMSPropMomentumState<B, D> {
+impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
 /// Moves the state to a device.
 ///
 /// # Arguments
@@ -378,7 +378,7 @@ mod tests {
 )
 .require_grad();
 
-let mut optimizer = RMSPropConfig::new()
+let mut optimizer = RmsPropConfig::new()
 .with_alpha(0.99)
 .with_epsilon(1e-8)
 .with_weight_decay(WeightDecayConfig::new(0.05).into())
@@ -453,7 +453,7 @@ mod tests {
 )
 .require_grad();
 
-let mut optimizer = RMSPropConfig::new()
+let mut optimizer = RmsPropConfig::new()
 .with_alpha(0.99)
 .with_epsilon(1e-8)
 .with_weight_decay(WeightDecayConfig::new(0.05).into())
@@ -529,9 +529,9 @@ mod tests {
 }
 
 fn create_rmsprop(
-) -> OptimizerAdaptor<RMSProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
+) -> OptimizerAdaptor<RmsProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
 {
-RMSPropConfig {
+RmsPropConfig {
 alpha: 0.99,
 epsilon: 1e-9,
 centered: false,
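Putting the renamed pieces together, the optimizer is built with the same builder calls shown in the tests above; a minimal fragment (the model type and backend are placeholders, and the named items are assumed to be in scope):

```rust
// Hypothetical usage; `MyModel<B>` stands in for any AutodiffModule.
let optimizer = RmsPropConfig::new()
    .with_alpha(0.99)
    .with_epsilon(1e-8)
    .with_weight_decay(WeightDecayConfig::new(0.05).into())
    .init::<B, MyModel<B>>();
```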
@@ -49,12 +49,12 @@ pub enum ImporterError {
 /// use serde::{Deserialize, Serialize};
 ///
 /// #[derive(Deserialize, Debug, Clone)]
-/// struct MNISTItemRaw {
+/// struct MnistItemRaw {
 /// pub image_bytes: Vec<u8>,
 /// pub label: usize,
 /// }
 ///
-/// let train_ds:SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
+/// let train_ds:SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
 /// .dataset("train")
 /// .unwrap();
 pub struct HuggingfaceDatasetLoader {
@@ -24,7 +24,7 @@ const HEIGHT: usize = 28;
 
 /// MNIST item.
 #[derive(Deserialize, Serialize, Debug, Clone)]
-pub struct MNISTItem {
+pub struct MnistItem {
 /// Image as a 2D array of floats.
 pub image: [[f32; WIDTH]; HEIGHT],
 
@@ -33,16 +33,16 @@ pub struct MNISTItem {
 }
 
 #[derive(Deserialize, Debug, Clone)]
-struct MNISTItemRaw {
+struct MnistItemRaw {
 pub image_bytes: Vec<u8>,
 pub label: u8,
 }
 
 struct BytesToImage;
 
-impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
+impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
 /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
-fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
+fn map(&self, item: &MnistItemRaw) -> MnistItem {
 // Ensure the image dimensions are correct.
 debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
 
@@ -54,25 +54,25 @@ impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
 image_array[y][x] = *pixel as f32;
 }
 
-MNISTItem {
+MnistItem {
 image: image_array,
 label: item.label,
 }
 }
 }
 
-type MappedDataset = MapperDataset<InMemDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
+type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;
 
 /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
 /// images per class. There are 60,000 training images and 10,000 test images.
 ///
 /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
-pub struct MNISTDataset {
+pub struct MnistDataset {
 dataset: MappedDataset,
 }
 
-impl Dataset<MNISTItem> for MNISTDataset {
-fn get(&self, index: usize) -> Option<MNISTItem> {
+impl Dataset<MnistItem> for MnistDataset {
+fn get(&self, index: usize) -> Option<MnistItem> {
 self.dataset.get(index)
 }
 
@@ -81,7 +81,7 @@ impl Dataset<MNISTItem> for MNISTDataset {
 }
 }
 
-impl MNISTDataset {
+impl MnistDataset {
 /// Creates a new train dataset.
 pub fn train() -> Self {
 Self::new("train")
@@ -94,19 +94,19 @@ impl MNISTDataset {
 
 fn new(split: &str) -> Self {
 // Download dataset
-let root = MNISTDataset::download(split);
+let root = MnistDataset::download(split);
 
 // MNIST is tiny so we can load it in-memory
 // Train images (u8): 28 * 28 * 60000 = 47.04Mb
 // Test images (u8): 28 * 28 * 10000 = 7.84Mb
-let images = MNISTDataset::read_images(&root, split);
-let labels = MNISTDataset::read_labels(&root, split);
+let images = MnistDataset::read_images(&root, split);
+let labels = MnistDataset::read_labels(&root, split);
 
-// Collect as vector of MNISTItemRaw
+// Collect as vector of MnistItemRaw
 let items: Vec<_> = images
 .into_iter()
 .zip(labels)
-.map(|(image_bytes, label)| MNISTItemRaw { image_bytes, label })
+.map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
 .collect();
 
 let dataset = InMemDataset::new(items);
@@ -132,12 +132,12 @@ impl MNISTDataset {
 // Download split files
 match split {
 "train" => {
-MNISTDataset::download_file(TRAIN_IMAGES, &split_dir);
-MNISTDataset::download_file(TRAIN_LABELS, &split_dir);
+MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
+MnistDataset::download_file(TRAIN_LABELS, &split_dir);
 }
 "test" => {
-MNISTDataset::download_file(TEST_IMAGES, &split_dir);
-MNISTDataset::download_file(TEST_LABELS, &split_dir);
+MnistDataset::download_file(TEST_IMAGES, &split_dir);
+MnistDataset::download_file(TEST_LABELS, &split_dir);
 }
 _ => panic!("Invalid split specified {}", split),
 };
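After the rename, loading the dataset and reading a single example looks like this; a minimal fragment assuming the `Dataset` trait and `MnistDataset` are in scope:

```rust
let dataset = MnistDataset::train();
// `get` returns Option<MnistItem>; index 0 is the first training example.
if let Some(item) = dataset.get(0) {
    println!("label: {}", item.label);
}
```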
@@ -1,6 +1,6 @@
 use burn::{
 module::Module,
-nn::{Linear, LinearConfig, ReLU},
+nn::{Linear, LinearConfig, Relu},
 tensor::{backend::Backend, Tensor},
 };
 
@@ -8,7 +8,7 @@ use burn::{
 pub struct Net<B: Backend> {
 fc1: Linear<B>,
 fc2: Linear<B>,
-relu: ReLU,
+relu: Relu,
 }
 
 impl<B: Backend> Net<B> {
@@ -16,7 +16,7 @@ impl<B: Backend> Net<B> {
 pub fn new_with(record: NetRecord<B>) -> Self {
 let fc1 = LinearConfig::new(2, 3).init_with(record.fc1);
 let fc2 = LinearConfig::new(3, 4).init_with(record.fc2);
-let relu = ReLU::default();
+let relu = Relu::default();
 
 Self { fc1, fc2, relu }
 }
@@ -9,7 +9,7 @@ use crate::onnx::{
 proto_conversion::convert_node_proto,
 };
 
-use super::ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor};
+use super::ir::{ArgType, Argument, Node, NodeType, OnnxGraph, Tensor};
 use super::protos::{ModelProto, TensorProto};
 use super::{dim_inference::dim_inference, protos::ValueInfoProto};
 
@@ -33,14 +33,14 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
 ///
 /// # Returns
 ///
-/// * `ONNXGraph` - The graph representation of the onnx file
+/// * `OnnxGraph` - The graph representation of the onnx file
 ///
 /// # Panics
 ///
 /// * If the file cannot be opened
 /// * If the file cannot be parsed
 /// * If the nodes are not topologically sorted
-pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
+pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
 log::info!("Parsing ONNX file: {}", onnx_path.display());
 
 // Open the file
@@ -118,7 +118,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
 
 log::info!("Finished parsing ONNX file: {}", onnx_path.display());
 
-ONNXGraph {
+OnnxGraph {
 nodes,
 inputs,
 outputs,
@@ -129,7 +129,7 @@ pub enum Data {
 
 /// ONNX graph representation
 #[derive(Debug, Clone)]
-pub struct ONNXGraph {
+pub struct OnnxGraph {
 /// The nodes of the graph.
 pub nodes: Vec<Node>,
 
@@ -11,4 +11,4 @@ mod to_burn;
 pub use to_burn::*;
 
 pub use from_onnx::parse_onnx;
-pub use ir::ONNXGraph;
+pub use ir::OnnxGraph;
@@ -45,7 +45,7 @@ use crate::{
 
 use super::{
 from_onnx::parse_onnx,
-ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph},
+ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
 op_configuration::{
 avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
 softmax_config,
@@ -218,7 +218,7 @@ impl ModelGen {
 }
 }
 
-impl ONNXGraph {
+impl OnnxGraph {
 /// Converts ONNX graph to Burn graph.
 pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
 let mut graph = BurnGraph::<PS>::default();
@@ -28,7 +28,7 @@ pub struct MlpConfig {
 pub struct Mlp<B: Backend> {
 linears: Vec<nn::Linear<B>>,
 dropout: nn::Dropout,
-activation: nn::ReLU,
+activation: nn::Relu,
 }
 
 impl<B: Backend> Mlp<B> {
@@ -43,7 +43,7 @@ impl<B: Backend> Mlp<B> {
 Self {
 linears,
 dropout: nn::DropoutConfig::new(0.3).init(),
-activation: nn::ReLU::new(),
+activation: nn::Relu::new(),
 }
 }
 
@@ -3,11 +3,11 @@ use crate::metric::{Metric, MetricEntry};
 use nvml_wrapper::Nvml;
 
 /// Track basic cuda infos.
-pub struct CUDAMetric {
+pub struct CudaMetric {
 nvml: Option<Nvml>,
 }
 
-impl CUDAMetric {
+impl CudaMetric {
 /// Creates a new metric for CUDA.
 pub fn new() -> Self {
 Self {
@@ -19,7 +19,7 @@ impl CUDAMetric {
 }
 }
 
-impl Default for CUDAMetric {
+impl Default for CudaMetric {
 fn default() -> Self {
 Self::new()
 }
@@ -29,7 +29,7 @@ impl<T> Adaptor<()> for T {
 fn adapt(&self) {}
 }
 
-impl Metric for CUDAMetric {
+impl Metric for CudaMetric {
 const NAME: &'static str = "CUDA Stats";
 
 type Input = ();
@@ -61,7 +61,7 @@
 //! - `audio`: Enables audio datasets (SpeechCommandsDataset)
 //! - `sqlite`: Stores datasets in SQLite database
 //! - `sqlite_bundled`: Use bundled version of SQLite
-//! - `vision`: Enables vision datasets (MNISTDataset)
+//! - `vision`: Enables vision datasets (MnistDataset)
 //! - Backends
 //! - `wgpu`: Makes available the WGPU backend
 //! - `candle`: Makes available the Candle backend
@@ -3,7 +3,7 @@ use burn::{
 nn::{
 conv::{Conv2d, Conv2dConfig},
 pool::{MaxPool2d, MaxPool2dConfig},
-Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, ReLU,
+Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
 },
 tensor::{backend::Backend, Device, Tensor},
 };
@@ -23,8 +23,8 @@ use burn::{
 // │ maxpool │
 // └────────────────────┘
 #[derive(Module, Debug)]
-pub struct CNN<B: Backend> {
-activation: ReLU,
+pub struct Cnn<B: Backend> {
+activation: Relu,
 dropout: Dropout,
 pool: MaxPool2d,
 conv1: Conv2d<B>,
@@ -37,7 +37,7 @@ pub struct CNN<B: Backend> {
 fc2: Linear<B>,
 }
 
-impl<B: Backend> CNN<B> {
+impl<B: Backend> Cnn<B> {
 pub fn new(num_classes: usize, device: &Device<B>) -> Self {
 let conv1 = Conv2dConfig::new([3, 32], [3, 3])
 .with_padding(PaddingConfig2d::Same)
@@ -68,7 +68,7 @@ impl<B: Backend> CNN<B> {
 let dropout = DropoutConfig::new(0.3).init();
 
 Self {
-activation: ReLU::new(),
+activation: Relu::new(),
 dropout,
 pool,
 conv1,
@ -3,7 +3,7 @@ use std::time::Instant;
|
||||||
use crate::{
|
use crate::{
|
||||||
data::{ClassificationBatch, ClassificationBatcher},
|
data::{ClassificationBatch, ClassificationBatcher},
|
||||||
dataset::CIFAR10Loader,
|
dataset::CIFAR10Loader,
|
||||||
model::CNN,
|
model::Cnn,
|
||||||
};
|
};
|
||||||
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
|
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
|
||||||
use burn::train::{
|
use burn::train::{
|
||||||
|
@ -26,7 +26,7 @@ use burn::{
|
||||||
const NUM_CLASSES: u8 = 10;
|
const NUM_CLASSES: u8 = 10;
|
||||||
const ARTIFACT_DIR: &str = "/tmp/custom-image-dataset";
|
const ARTIFACT_DIR: &str = "/tmp/custom-image-dataset";
|
||||||
|
|
||||||
impl<B: Backend> CNN<B> {
|
impl<B: Backend> Cnn<B> {
|
||||||
pub fn forward_classification(
|
pub fn forward_classification(
|
||||||
&self,
|
&self,
|
||||||
images: Tensor<B, 4>,
|
images: Tensor<B, 4>,
|
||||||
|
@ -41,7 +41,7 @@ impl<B: Backend> CNN<B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<B>> for CNN<B> {
|
impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<B>> for Cnn<B> {
|
||||||
fn step(&self, batch: ClassificationBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
|
fn step(&self, batch: ClassificationBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
|
||||||
let item = self.forward_classification(batch.images, batch.targets);
|
let item = self.forward_classification(batch.images, batch.targets);
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, ClassificationOutput<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> ValidStep<ClassificationBatch<B>, ClassificationOutput<B>> for CNN<B> {
|
impl<B: Backend> ValidStep<ClassificationBatch<B>, ClassificationOutput<B>> for Cnn<B> {
|
||||||
fn step(&self, batch: ClassificationBatch<B>) -> ClassificationOutput<B> {
|
fn step(&self, batch: ClassificationBatch<B>) -> ClassificationOutput<B> {
|
||||||
self.forward_classification(batch.images, batch.targets)
|
self.forward_classification(batch.images, batch.targets)
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
|
||||||
.devices(vec![device.clone()])
|
.devices(vec![device.clone()])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
.build(
|
.build(
|
||||||
CNN::new(NUM_CLASSES.into(), &device),
|
Cnn::new(NUM_CLASSES.into(), &device),
|
||||||
config.optimizer.init(),
|
config.optimizer.init(),
|
||||||
config.learning_rate,
|
config.learning_rate,
|
||||||
);
|
);
|
||||||
|
|
|
@@ -1,11 +1,11 @@
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
 use burn::train::LearnerBuilder;
 use burn::{
     config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
     tensor::backend::AutodiffBackend,
 };
-use guide::{data::MNISTBatcher, model::ModelConfig};
+use guide::{data::MnistBatcher, model::ModelConfig};

 #[derive(Config)]
 pub struct MnistTrainingConfig {

@@ -52,21 +52,21 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
     let optim = config.optimizer.init();

     // Create the batcher.
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

     // Create the dataloaders.
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

     // artifact dir does not need to be provided when log_to_file is false
     let builder = LearnerBuilder::new("")
@@ -1,6 +1,6 @@
 use std::marker::PhantomData;

-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::{
     config::Config,
     data::dataloader::DataLoaderBuilder,

@@ -13,7 +13,7 @@ use burn::{
     },
 };
 use guide::{
-    data::{MNISTBatch, MNISTBatcher},
+    data::{MnistBatch, MnistBatcher},
     model::{Model, ModelConfig},
 };

@@ -46,21 +46,21 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
     let mut optim = config.optimizer.init();

     // Create the batcher.
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

     // Create the dataloaders.
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

     // Iterate over our training and validation loop for X epochs.
     for epoch in 1..config.num_epochs + 1 {

@@ -145,7 +145,7 @@ where
     B: AutodiffBackend,
     O: Optimizer<Model<B>, B>,
 {
-    pub fn step1(&mut self, _batch: MNISTBatch<B>) {
+    pub fn step1(&mut self, _batch: MnistBatch<B>) {
         //
     }
 }

@@ -156,14 +156,14 @@ where
     B: AutodiffBackend,
     O: Optimizer<Model<B>, B>,
 {
-    pub fn step2(&mut self, _batch: MNISTBatch<B>) {
+    pub fn step2(&mut self, _batch: MnistBatch<B>) {
         //
     }
 }

 #[allow(dead_code)]
 impl<M, O> Learner2<M, O> {
-    pub fn step3<B: AutodiffBackend>(&mut self, _batch: MNISTBatch<B>)
+    pub fn step3<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
     where
         B: AutodiffBackend,
         M: AutodiffModule<B>,
@@ -18,7 +18,7 @@ fn main() {
     guide::inference::infer::<MyBackend>(
         artifact_dir,
         device,
-        burn::data::dataset::vision::MNISTDataset::test()
+        burn::data::dataset::vision::MnistDataset::test()
             .get(42)
             .unwrap(),
     );
@@ -1,26 +1,26 @@
 use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
     tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
 };

-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
     device: B::Device,
 }

-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
     pub fn new(device: B::Device) -> Self {
         Self { device }
     }
 }

 #[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
     pub images: Tensor<B, 3>,
     pub targets: Tensor<B, 1, Int>,
 }

-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
         let images = items
             .iter()
             .map(|item| Data::<f32, 2>::from(item.image))

@@ -40,6 +40,6 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
         let images = Tensor::cat(images, 0);
         let targets = Tensor::cat(targets, 0);

-        MNISTBatch { images, targets }
+        MnistBatch { images, targets }
     }
 }
@@ -1,5 +1,5 @@
-use crate::{data::MNISTBatcher, training::TrainingConfig};
-use burn::data::dataset::vision::MNISTItem;
+use crate::{data::MnistBatcher, training::TrainingConfig};
+use burn::data::dataset::vision::MnistItem;
 use burn::{
     config::Config,
     data::dataloader::batcher::Batcher,

@@ -7,7 +7,7 @@ use burn::{
     tensor::backend::Backend,
 };

-pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
     let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
         .expect("Config should exist for the model");
     let record = CompactRecorder::new()

@@ -17,7 +17,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
     let model = config.model.init_with::<B>(record);

     let label = item.label;
-    let batcher = MNISTBatcher::new(device);
+    let batcher = MnistBatcher::new(device);
     let batch = batcher.batch(vec![item]);
     let output = model.forward(batch.images);
     let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
@@ -4,7 +4,7 @@ use burn::{
     nn::{
         conv::{Conv2d, Conv2dConfig},
         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
-        Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
+        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
     },
     tensor::{backend::Backend, Tensor},
 };

@@ -17,7 +17,7 @@ pub struct Model<B: Backend> {
     dropout: Dropout,
     linear1: Linear<B>,
     linear2: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
 }

 #[derive(Config, Debug)]

@@ -35,7 +35,7 @@ impl ModelConfig {
             conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
             conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
             pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
             linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
             linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
             dropout: DropoutConfig::new(self.dropout).init(),

@@ -47,7 +47,7 @@ impl ModelConfig {
             conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
             conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
             pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
-            activation: ReLU::new(),
+            activation: Relu::new(),
             linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
             linear2: LinearConfig::new(self.hidden_size, self.num_classes)
                 .init_with(record.linear2),
@@ -1,8 +1,8 @@
 use crate::{
-    data::{MNISTBatch, MNISTBatcher},
+    data::{MnistBatch, MnistBatcher},
     model::{Model, ModelConfig},
 };
-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::train::{
     metric::{AccuracyMetric, LossMetric},
     ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,

@@ -36,16 +36,16 @@ impl<B: Backend> Model<B> {
     }
 }

-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
         let item = self.forward_classification(batch.images, batch.targets);

         TrainOutput::new(self, item.loss.backward(), item)
     }
 }

-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
         self.forward_classification(batch.images, batch.targets)
     }
 }

@@ -74,20 +74,20 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev

     B::seed(config.seed);

-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());

     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

     let learner = LearnerBuilder::new(artifact_dir)
         .metric_train_numeric(AccuracyMetric::new())
@@ -1,26 +1,26 @@
 use burn::{
-    data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
     tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
 };

-pub struct MNISTBatcher<B: Backend> {
+pub struct MnistBatcher<B: Backend> {
     device: B::Device,
 }

 #[derive(Clone, Debug)]
-pub struct MNISTBatch<B: Backend> {
+pub struct MnistBatch<B: Backend> {
     pub images: Tensor<B, 3>,
     pub targets: Tensor<B, 1, Int>,
 }

-impl<B: Backend> MNISTBatcher<B> {
+impl<B: Backend> MnistBatcher<B> {
     pub fn new(device: B::Device) -> Self {
         Self { device }
     }
 }

-impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
-    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
         let images = items
             .iter()
             .map(|item| Data::<f32, 2>::from(item.image))

@@ -45,6 +45,6 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
         let images = Tensor::cat(images, 0);
         let targets = Tensor::cat(targets, 0);

-        MNISTBatch { images, targets }
+        MnistBatch { images, targets }
     }
 }
@@ -1,4 +1,4 @@
-use crate::data::MNISTBatch;
+use crate::data::MnistBatch;
 use burn::{
     module::Module,
     nn::{self, loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},

@@ -73,7 +73,7 @@ impl<B: Backend> Model<B> {
         self.fc2.forward(x)
     }

-    pub fn forward_classification(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
+    pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
         let targets = item.targets;
         let output = self.forward(item.images);
         let loss = CrossEntropyLossConfig::new()

@@ -117,16 +117,16 @@ impl<B: Backend> ConvBlock<B> {
     }
 }

-impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, item: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
+impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, item: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
         let item = self.forward_classification(item);

         TrainOutput::new(self, item.loss.backward(), item)
     }
 }

-impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
+impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
+    fn step(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
         self.forward_classification(item)
     }
 }
@@ -1,4 +1,4 @@
-use crate::data::MNISTBatcher;
+use crate::data::MnistBatcher;
 use crate::model::Model;

 use burn::module::Module;

@@ -10,7 +10,7 @@ use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
 use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
 use burn::{
     config::Config,
-    data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
+    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
     tensor::backend::AutodiffBackend,
     train::{
         metric::{AccuracyMetric, LossMetric},

@@ -44,19 +44,19 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
     B::seed(config.seed);

     // Data
-    let batcher_train = MNISTBatcher::<B>::new(device.clone());
-    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::train());
+        .build(MnistDataset::train());
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(MNISTDataset::test());
+        .build(MnistDataset::test());

     // Model
     let learner = LearnerBuilder::new(ARTIFACT_DIR)
@@ -3,7 +3,7 @@ use std::env::args;
 use burn::backend::ndarray::NdArray;
 use burn::tensor::Tensor;

-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::data::dataset::Dataset;

 use onnx_inference::mnist::Model;

@@ -34,7 +34,7 @@ fn main() {
     let model: Model<Backend> = Model::default();

     // Load the MNIST dataset and get an item
-    let dataset = MNISTDataset::test();
+    let dataset = MnistDataset::test();
     let item = dataset.get(image_index).unwrap();

     // Create a tensor from the image data

@@ -5,7 +5,7 @@ use burn::backend::ndarray::NdArray;
 use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
 use burn::tensor::Tensor;

-use burn::data::dataset::vision::MNISTDataset;
+use burn::data::dataset::vision::MnistDataset;
 use burn::data::dataset::Dataset;

 use model::Model;

@@ -42,7 +42,7 @@ fn main() {
     let model: Model<Backend> = Model::new_with(record);

     // Load the MNIST dataset and get an item
-    let dataset = MNISTDataset::test();
+    let dataset = MnistDataset::test();
     let item = dataset.get(image_index).unwrap();

     // Create a tensor from the image data
@@ -1,10 +1,10 @@
 use crate::dataset::DiabetesBatch;
 use burn::config::Config;
 use burn::nn::loss::Reduction::Mean;
-use burn::nn::ReLU;
+use burn::nn::Relu;
 use burn::{
     module::Module,
-    nn::{loss::MSELoss, Linear, LinearConfig},
+    nn::{loss::MseLoss, Linear, LinearConfig},
     tensor::{
         backend::{AutodiffBackend, Backend},
         Tensor,

@@ -16,7 +16,7 @@ use burn::{
 pub struct RegressionModel<B: Backend> {
     input_layer: Linear<B>,
     output_layer: Linear<B>,
-    activation: ReLU,
+    activation: Relu,
 }

 #[derive(Config)]

@@ -39,7 +39,7 @@ impl RegressionModelConfig {
         RegressionModel {
             input_layer,
             output_layer,
-            activation: ReLU::new(),
+            activation: Relu::new(),
         }
     }
 }

@@ -56,7 +56,7 @@ impl<B: Backend> RegressionModel<B> {
         let targets: Tensor<B, 2> = item.targets.unsqueeze();
         let output: Tensor<B, 2> = self.forward(item.inputs);

-        let loss = MSELoss::new().forward(output.clone(), targets.clone(), Mean);
+        let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);

         RegressionOutput {
             loss,
@@ -19,7 +19,7 @@ use burn::{
     record::{CompactRecorder, Recorder},
     tensor::backend::AutodiffBackend,
     train::{
-        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
         LearnerBuilder,
     },
 };

@@ -91,8 +91,8 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(

     // Initialize learner
     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train(CUDAMetric::new())
-        .metric_valid(CUDAMetric::new())
+        .metric_train(CudaMetric::new())
+        .metric_valid(CudaMetric::new())
         .metric_train_numeric(AccuracyMetric::new())
         .metric_valid_numeric(AccuracyMetric::new())
         .metric_train_numeric(LossMetric::new())

@@ -13,7 +13,7 @@ use burn::{
     record::{CompactRecorder, DefaultRecorder, Recorder},
     tensor::backend::AutodiffBackend,
     train::{
-        metric::{AccuracyMetric, CUDAMetric, LearningRateMetric, LossMetric},
+        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
         LearnerBuilder,
     },
 };

@@ -68,8 +68,8 @@ pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
         .init();

     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train(CUDAMetric::new())
-        .metric_valid(CUDAMetric::new())
+        .metric_train(CudaMetric::new())
+        .metric_valid(CudaMetric::new())
         .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
         .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
         .metric_train(LossMetric::new())