# MNIST Inference on Web

[![Demo up](https://img.shields.io/badge/demo-up-brightgreen)](https://burn-rs.github.io/mnist)

This crate demonstrates how to run an MNIST-trained model in the browser for inference.
## Running

1. Build:

```shell
./build-for-web.sh
```

2. Run the server:

```shell
./run-server.sh
```

3. Open [`http://[::]:8000/`](http://[::]:8000/) or
[`http://localhost:8000/`](http://localhost:8000/) in the browser.
## Design

The inference components of `burn` with the `ndarray` backend can be built with `#![no_std]`. This
makes it possible to build and run the model with the `wasm32-unknown-unknown` target without a
special system library, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) for how to
include `burn` dependencies without `std`.)
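The dependency declaration might look roughly like the following sketch. The crate names, versions, and feature flags here are assumptions for illustration; consult the actual [Cargo.toml](./Cargo.toml) for the real configuration.

```toml
# Hypothetical sketch only -- disabling default features is the general Cargo
# mechanism for dropping std-dependent functionality from a dependency.
[dependencies]
burn = { version = "*", default-features = false }
burn-ndarray = { version = "*", default-features = false }
```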
For this demo, we use the trained parameters (`model-6.json.gz`) and the model (`model.rs`) from the
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).

At build time, `model-6.json.gz` is converted to
[`bincode`](https://github.com/bincode-org/bincode) (for compactness) and included as part of the
final wasm output. At runtime, the MNIST model is initialized with the trained weights from memory.
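The "weights embedded in the binary, parsed at runtime" idea can be sketched as follows. This is a simplified illustration: the demo actually embeds bincode-encoded parameters (pulled in at compile time, e.g. via `include_bytes!`), whereas this sketch just parses raw little-endian `f32` bytes.

```rust
// Simplified illustration: parameters live in the binary's data section and are
// decoded into f32 values at runtime. The real demo decodes bincode instead.
static WEIGHTS: &[u8] = &[
    0x00, 0x00, 0x80, 0x3F, // 1.0f32, little-endian
    0x00, 0x00, 0x00, 0x40, // 2.0f32, little-endian
];

fn load_params(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let params = load_params(WEIGHTS);
    assert_eq!(params, vec![1.0, 2.0]);
}
```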
The inference API for JavaScript is exposed with the help of
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.
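The shape of such an exported API can be sketched like this. The function name and signature are hypothetical, not the demo's actual bindings; with `wasm-bindgen` the function would carry a `#[wasm_bindgen]` attribute, omitted here so the sketch compiles standalone.

```rust
// Hypothetical JS-facing API shape. In the real crate this would be annotated
// with #[wasm_bindgen] so JavaScript can call it directly.
pub fn mnist_inference(image: &[f32]) -> Vec<f32> {
    assert_eq!(image.len(), 28 * 28);
    // The model would run here; this stub just returns uniform scores.
    vec![1.0 / 10.0; 10]
}

fn main() {
    let probs = mnist_inference(&[0.0; 784]);
    assert_eq!(probs.len(), 10);
}
```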
JavaScript (`index.js`) is used to transform hand-drawn digits to a format that the inference API
accepts. The transformation includes image cropping, scaling down, and converting it to grayscale
values.
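The grayscale step can be sketched as below (in Rust for illustration; the demo performs it in JavaScript, and the exact channel weighting used there is an assumption).

```rust
// Illustrative RGBA -> grayscale conversion using Rec. 601 luma weights;
// the demo's JavaScript may weight the channels differently.
fn to_grayscale(rgba: &[u8]) -> Vec<f32> {
    rgba.chunks_exact(4)
        .map(|p| {
            let (r, g, b) = (p[0] as f32, p[1] as f32, p[2] as f32);
            (0.299 * r + 0.587 * g + 0.114 * b) / 255.0
        })
        .collect()
}

fn main() {
    // One white pixel and one black pixel (RGBA, alpha ignored).
    let gray = to_grayscale(&[255, 255, 255, 255, 0, 0, 0, 255]);
    assert!((gray[0] - 1.0).abs() < 1e-3);
    assert_eq!(gray[1], 0.0);
}
```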
## Model

Layers:

1. Input Image (28x28, 1ch)
2. `Conv2d`(3x3, 8ch), `GELU`
3. `Conv2d`(3x3, 16ch), `GELU`
4. `Conv2d`(3x3, 24ch), `GELU`
5. `Linear`(11616, 32), `GELU`
6. `Linear`(32, 10)
7. Softmax Output
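The softmax output stage can be sketched minimally as follows (an illustration, not burn's implementation).

```rust
// Minimal softmax sketch: exponentiate shifted logits and normalize to a
// probability distribution over the 10 digit classes.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max for numerical stability before exponentiating.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax(&[0.0; 10]);
    // Uniform logits give uniform probabilities.
    assert!((probs[0] - 0.1).abs() < 1e-6);
    let total: f32 = probs.iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
}
```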
The total number of parameters is 376,712.
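The total can be reproduced from the layer list above. This sketch assumes bias-free layers (an assumption, but it is what makes the arithmetic match the stated 376,712 exactly).

```rust
// Parameter-count sketch, assuming no bias terms in any layer.
fn conv2d_params(cin: usize, cout: usize, k: usize) -> usize {
    cin * cout * k * k
}
fn linear_params(input: usize, output: usize) -> usize {
    input * output
}

fn main() {
    // Three unpadded 3x3 convolutions shrink 28x28 to 22x22,
    // so the flattened size is 24 * 22 * 22 = 11,616.
    assert_eq!(24 * 22 * 22, 11_616);
    let total = conv2d_params(1, 8, 3)  // 72
        + conv2d_params(8, 16, 3)       // 1,152
        + conv2d_params(16, 24, 3)      // 3,456
        + linear_params(11_616, 32)     // 371,712
        + linear_params(32, 10);        // 320
    assert_eq!(total, 376_712);
}
```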
The model is trained for 6 epochs and the final test accuracy is 98.03%.

The training and hyperparameter information can be found in the
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).
## Comparison

The main differentiating factor between this example's approach (compiling the Rust model to wasm) and
other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),
[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/), and
[TVM Web](https://github.com/apache/tvm/tree/main/web), is the absence of runtime code: the Rust
compiler optimizes and includes only the `burn` routines that are actually used. Of the 1,831,094-byte
wasm file, 1,507,884 bytes are the model's parameters; the remaining 323,210 bytes contain all the code
(including `burn`'s `nn` components, the data deserialization library, and math operations).
## Future Improvements

There are two planned enhancements to `burn`:

- [#201](https://github.com/burn-rs/burn/issues/201) - Saving the model's params in binary format. This
  will simplify the inference code.
- [#202](https://github.com/burn-rs/burn/issues/202) - Saving the model's params in half-precision and
  loading them back in full. This could halve the size of the wasm file.
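The idea behind #202 can be sketched as follows. For illustration this uses bfloat16 truncation (keeping the top 16 bits of an `f32`), because it needs no external crates; burn's eventual format may use IEEE half-precision instead.

```rust
// Sketch of "store narrow, load wide": truncate f32 weights to bf16 (2 bytes)
// for storage, widen back to f32 at load time. Not burn's actual mechanism.
fn to_bf16(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}
fn from_bf16(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

fn main() {
    let weights = [0.5f32, -1.25, 3.0];
    let stored: Vec<u16> = weights.iter().map(|&w| to_bf16(w)).collect();
    let restored: Vec<f32> = stored.iter().map(|&b| from_bf16(b)).collect();
    // These constants fit in bf16's 8-bit mantissa, so they round-trip exactly;
    // general weights would lose low-order precision.
    assert_eq!(restored, weights);
    // Storage is 2 bytes per weight instead of 4 -- half the size.
    assert_eq!(std::mem::size_of_val(&stored[..]), 6);
}
```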
Two future technological developments are worth mentioning that could speed up inference in the browser.
A [WebGPU](https://github.com/gfx-rs/wgpu) backend could be developed to speed up the computation.
Also, if NDArray at some point adds
[WASM SIMD](https://github.com/WebAssembly/simd/blob/master/proposals/simd/SIMD.md) support,
CPU computation could potentially improve as well.
## Acknowledgement

Two online MNIST demos inspired and helped build this demo:
[MNIST Draw](https://mco-mnist-draw-rwpxka3zaa-ue.a.run.app/) by Marc (@mco-gh) and
[MNIST Web Demo](https://ufal.mff.cuni.cz/~straka/courses/npfl129/2223/demos/mnist_web.html) (no
code was copied, but they helped tremendously with the implementation approach).
## Resources

1. [Rust 🦀 and WebAssembly](https://rustwasm.github.io/docs/book/)
2. [wasm-bindgen](https://rustwasm.github.io/wasm-bindgen/)