mirror of https://github.com/tracel-ai/burn.git
Fix tensor data elem type conversion in book (#2211)
This commit is contained in:
parent
0292967000
commit
40d321cc0d
|
@ -68,8 +68,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
|
||||||
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
|
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
|
||||||
let images = items
|
let images = items
|
||||||
.iter()
|
.iter()
|
||||||
.map(|item| TensorData::from(item.image))
|
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
|
||||||
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device))
|
.map(|data| Tensor::<B, 2>::from_data(data, &self.device))
|
||||||
.map(|tensor| tensor.reshape([1, 28, 28]))
|
.map(|tensor| tensor.reshape([1, 28, 28]))
|
||||||
// Normalize: make between [0,1] and make the mean=0 and std=1
|
// Normalize: make between [0,1] and make the mean=0 and std=1
|
||||||
// values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
|
// values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
|
||||||
|
@ -119,8 +119,8 @@ images.
|
||||||
```rust, ignore
|
```rust, ignore
|
||||||
let images = items // take items Vec<MnistItem>
|
let images = items // take items Vec<MnistItem>
|
||||||
.iter() // create an iterator over it
|
.iter() // create an iterator over it
|
||||||
.map(|item| TensorData::from(item.image)) // for each item, convert the image to float32 data struct
|
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>()) // for each item, convert the image to float data struct
|
||||||
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device
|
.map(|data| Tensor::<B, 2>::from_data(data, &self.device)) // for each data struct, create a tensor on the device
|
||||||
.map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W]
|
.map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W]
|
||||||
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization
|
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization
|
||||||
.collect(); // consume the resulting iterator & collect the values into a new vector
|
.collect(); // consume the resulting iterator & collect the values into a new vector
|
||||||
|
@ -138,5 +138,6 @@ a targets tensor that contains the indexes of the correct digit class. The first
|
||||||
the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate
|
the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate
|
||||||
tensor storage information without being specific for a backend. When creating a tensor from data,
|
tensor storage information without being specific for a backend. When creating a tensor from data,
|
||||||
we often need to convert the data precision to the current backend in use. This can be done with the
|
we often need to convert the data precision to the current backend in use. This can be done with the
|
||||||
`.convert()` method. While importing the `burn::tensor::ElementConversion` trait, you can call
|
`.convert()` method (in this example, the data is converted backend's float element type
|
||||||
`.elem()` on a specific number to convert it to the current backend element type in use.
|
`B::FloatElem`). While importing the `burn::tensor::ElementConversion` trait, you can call `.elem()`
|
||||||
|
on a specific number to convert it to the current backend element type in use.
|
||||||
|
|
|
@ -24,8 +24,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
|
||||||
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
|
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
|
||||||
let images = items
|
let images = items
|
||||||
.iter()
|
.iter()
|
||||||
.map(|item| TensorData::from(item.image))
|
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
|
||||||
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device))
|
.map(|data| Tensor::<B, 2>::from_data(data, &self.device))
|
||||||
.map(|tensor| tensor.reshape([1, 28, 28]))
|
.map(|tensor| tensor.reshape([1, 28, 28]))
|
||||||
// normalize: make between [0,1] and make the mean = 0 and std = 1
|
// normalize: make between [0,1] and make the mean = 0 and std = 1
|
||||||
// values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example
|
// values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example
|
||||||
|
|
Loading…
Reference in New Issue