Fix tensor data elem type conversion in book (#2211)

This commit is contained in:
Guillaume Lagrange 2024-08-28 10:55:10 -04:00 committed by GitHub
parent 0292967000
commit 40d321cc0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 8 deletions

View File

@ -68,8 +68,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> { fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items let images = items
.iter() .iter()
.map(|item| TensorData::from(item.image)) .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) .map(|data| Tensor::<B, 2>::from_data(data, &self.device))
.map(|tensor| tensor.reshape([1, 28, 28])) .map(|tensor| tensor.reshape([1, 28, 28]))
// Normalize: make between [0,1] and make the mean=0 and std=1 // Normalize: make between [0,1] and make the mean=0 and std=1
// values mean=0.1307,std=0.3081 are from the PyTorch MNIST example // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
@ -119,8 +119,8 @@ images.
```rust, ignore ```rust, ignore
let images = items // take items Vec<MnistItem> let images = items // take items Vec<MnistItem>
.iter() // create an iterator over it .iter() // create an iterator over it
.map(|item| TensorData::from(item.image)) // for each item, convert the image to float32 data struct .map(|item| TensorData::from(item.image).convert::<B::FloatElem>()) // for each item, convert the image to float data struct
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device .map(|data| Tensor::<B, 2>::from_data(data, &self.device)) // for each data struct, create a tensor on the device
.map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W] .map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W]
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization
.collect(); // consume the resulting iterator & collect the values into a new vector .collect(); // consume the resulting iterator & collect the values into a new vector
@ -138,5 +138,6 @@ a targets tensor that contains the indexes of the correct digit class. The first
the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate
tensor storage information without being specific for a backend. When creating a tensor from data, tensor storage information without being specific for a backend. When creating a tensor from data,
we often need to convert the data precision to the current backend in use. This can be done with the we often need to convert the data precision to the current backend in use. This can be done with the
`.convert()` method. While importing the `burn::tensor::ElementConversion` trait, you can call `.convert()` method (in this example, the data is converted backend's float element type
`.elem()` on a specific number to convert it to the current backend element type in use. `B::FloatElem`). While importing the `burn::tensor::ElementConversion` trait, you can call `.elem()`
on a specific number to convert it to the current backend element type in use.

View File

@ -24,8 +24,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> { fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items let images = items
.iter() .iter()
.map(|item| TensorData::from(item.image)) .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device)) .map(|data| Tensor::<B, 2>::from_data(data, &self.device))
.map(|tensor| tensor.reshape([1, 28, 28])) .map(|tensor| tensor.reshape([1, 28, 28]))
// normalize: make between [0,1] and make the mean = 0 and std = 1 // normalize: make between [0,1] and make the mean = 0 and std = 1
// values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example