# MNIST Inference on Web
[![Demo up](https://img.shields.io/badge/demo-up-brightgreen)](https://burn-rs.github.io/mnist)

This crate demonstrates how to run an MNIST-trained model in the browser for inference.

## Running

1. Build

   ```shell
   ./build-for-web.sh
   ```

2. Run the server

   ```shell
   ./run-server.sh
   ```

3. Open [`http://[::]:8000/`](http://[::]:8000/) or
   [`http://localhost:8000/`](http://localhost:8000/) in the browser.

## Design

The inference components of `burn` with the `ndarray` backend can be built with `#![no_std]`. This
makes it possible to build and run the model with the `wasm32-unknown-unknown` target without a
special system interface, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) for how
to include `burn` dependencies without `std`.)
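
As a rough sketch (not this crate's exact setup), a `no_std` wasm library needs the attribute plus
a global allocator and a panic handler, since `std` normally provides both; `wee_alloc` below is one
common wasm-friendly allocator, shown purely for illustration:

```rust
// Hypothetical #![no_std] crate root for wasm32-unknown-unknown;
// this crate's real setup may differ.
#![no_std]

extern crate alloc; // inference code can still use Vec & co. via `alloc`

// Without std, a global allocator must be supplied explicitly.
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;

// Without std there is also no default panic handler.
#[panic_handler]
fn on_panic(_info: &core::panic::PanicInfo) -> ! {
    loop {}
}
```
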
For this demo, we use the trained parameters (`model-6.json.gz`) and the model definition
(`model.rs`) from the
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).
At build time, `model-6.json.gz` is converted to
[`bincode`](https://github.com/bincode-org/bincode) (for compactness) and embedded in the final wasm
output. At runtime, the MNIST model is initialized with the trained weights from memory.
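
A minimal sketch of this embed-and-deserialize pattern, assuming a hypothetical `ModelState` record
and output file name (the real state type comes from the MNIST example's model):

```rust
use serde::Deserialize;

// Hypothetical stand-in for the real parameter record.
#[derive(Deserialize)]
struct ModelState {
    weights: Vec<f32>,
}

// The bincode bytes produced at build time are compiled straight into
// the wasm binary. The file name here is illustrative.
static STATE_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.bin"));

fn load_state() -> ModelState {
    // bincode deserializes directly from the embedded byte slice.
    bincode::deserialize(STATE_BYTES).expect("embedded model state is valid")
}
```
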
The inference API is exposed to JavaScript using the
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen) library and tools.
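
For illustration, an exported entry point could look like the sketch below; the `Mnist` type and
`inference` method are assumptions, not necessarily this crate's exact exported surface:

```rust
use wasm_bindgen::prelude::*;

// Sketch of a wasm-bindgen-exported inference handle.
#[wasm_bindgen]
pub struct Mnist {
    // the initialized model would be stored here
}

#[wasm_bindgen]
impl Mnist {
    #[wasm_bindgen(constructor)]
    pub fn new() -> Mnist {
        Mnist {}
    }

    /// Accepts 28 * 28 = 784 grayscale values; returns 10 class scores.
    pub fn inference(&self, input: &[f32]) -> Vec<f32> {
        // ... run the model on `input` and apply softmax ...
        vec![0.0; 10]
    }
}
```
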
JavaScript (`index.js`) transforms the hand-drawn digit into the format the inference API accepts.
The transformation includes cropping the image, scaling it down, and converting it to grayscale
values.
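
The grayscale step amounts to collapsing each RGBA canvas pixel into one float. A sketch of that
math (in Rust for consistency; the demo actually does this in `index.js`, and the standard luma
weights shown are an assumption):

```rust
/// Collapse RGBA pixels into the per-pixel grayscale floats the
/// inference API consumes. Illustrative only.
fn rgba_to_grayscale(rgba: &[u8]) -> Vec<f32> {
    rgba.chunks_exact(4)
        .map(|px| {
            let (r, g, b) = (px[0] as f32, px[1] as f32, px[2] as f32);
            // Standard luma weights, normalized to [0, 1].
            (0.299 * r + 0.587 * g + 0.114 * b) / 255.0
        })
        .collect()
}
```
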
## Model

Layers:

1. Input Image (28, 28, 1ch)
2. `Conv2d`(3x3, 8ch), `GELU`
3. `Conv2d`(3x3, 16ch), `GELU`
4. `Conv2d`(3x3, 24ch), `GELU`
5. `Linear`(11616, 32), `GELU`
6. `Linear`(32, 10)
7. Softmax Output

The total number of parameters is 376,712.
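
For reference, the total can be reproduced layer by layer. The breakdown below matches the stated
figure only if the layers carry no bias terms, and the `Linear` input size 11616 = 24ch x 22 x 22 is
consistent with three unpadded 3x3 convolutions shrinking 28x28 to 22x22:

```
Conv2d 3x3, 1 -> 8ch:    3*3*1*8    =      72
Conv2d 3x3, 8 -> 16ch:   3*3*8*16   =   1,152
Conv2d 3x3, 16 -> 24ch:  3*3*16*24  =   3,456
Linear 11616 -> 32:      11616*32   = 371,712
Linear 32 -> 10:         32*10      =     320
Total                               = 376,712
```
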
The model is trained for 6 epochs and reaches a final test accuracy of 98.03%.

The training and hyperparameter information can be found in the
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).

## Comparison

The main differentiating factor between this example's approach (compiling a Rust model to wasm)
and other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),
[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/), and
[TVM Web](https://github.com/apache/tvm/tree/main/web), is the absence of runtime code: the Rust
compiler optimizes and includes only the `burn` routines that are actually used. Of the
1,831,094-byte wasm file, 1,507,884 bytes are the model's parameters; the remaining 323,210 bytes
contain all the code (including `burn`'s `nn` components, the data deserialization library, and
math operations).

## Future Improvements

Two enhancements to `burn` are planned:

- [#201](https://github.com/burn-rs/burn/issues/201) - Saving the model's parameters in a binary
  format. This will simplify the inference code.
- [#202](https://github.com/burn-rs/burn/issues/202) - Saving the model's parameters in
  half-precision and loading them back in full. This could halve the size of the wasm file.

Two future technological developments that could speed up inference in the browser are also worth
mentioning: a [WebGPU](https://github.com/gfx-rs/wgpu) backend could be developed to speed up the
computation, and if NDArray at some point adds
[WASM SIMD](https://github.com/WebAssembly/simd/blob/master/proposals/simd/SIMD.md) support,
CPU computation could potentially improve as well.

## Acknowledgement

Two online MNIST demos inspired and informed this demo:
[MNIST Draw](https://mco-mnist-draw-rwpxka3zaa-ue.a.run.app/) by Marc (@mco-gh) and
[MNIST Web Demo](https://ufal.mff.cuni.cz/~straka/courses/npfl129/2223/demos/mnist_web.html) (no
code was copied, but they helped tremendously with the implementation approach).

## Resources
1. [Rust 🦀 and WebAssembly](https://rustwasm.github.io/docs/book/)
2. [wasm-bindgen](https://rustwasm.github.io/wasm-bindgen/)