# MNIST Inference on Web

[![Live Demo](https://img.shields.io/badge/live-demo-brightgreen)](https://burn.dev/demo)

This crate demonstrates how to run an MNIST-trained model in the browser for inference.

## Running

1. Build

```shell
./build-for-web.sh {backend}
```

The backend can be either `ndarray` or `wgpu`. Note that `wgpu` only works in browsers that support WebGPU.

2. Run the server

```shell
./run-server.sh
```

3. Open [`http://localhost:8000/`](http://localhost:8000/) in the browser.

## Design

The inference components of `burn` with the `ndarray` backend can be built with `#![no_std]`. This
makes it possible to build and run the model with the `wasm32-unknown-unknown` target without a
special system library, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) for how to
include `burn` dependencies without `std`.)

For this demo, we use the trained parameters (`model.bin`) and the model (`model.rs`) from the
[`burn` MNIST example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).

The inference API for JavaScript is exposed with the help of
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.

JavaScript (`index.js`) is used to transform hand-drawn digits into a format that the inference API
accepts. The transformation includes cropping the image, scaling it down, and converting it to
grayscale values.

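
The preprocessing steps described above (crop, scale down, convert to grayscale) can be pictured
with a minimal Rust sketch. This is not the demo's code, which lives in `index.js`: the function
names, the RGBA input layout, the luma weights, the brightness threshold, and the box-average
downscaling are all assumptions made for illustration, and the sketch assumes light strokes on a
dark background.

```rust
/// Converts RGBA bytes to grayscale values in [0, 1] using Rec. 601 luma weights (assumed).
fn to_grayscale(rgba: &[u8]) -> Vec<f32> {
    rgba.chunks_exact(4)
        .map(|p| (0.299 * p[0] as f32 + 0.587 * p[1] as f32 + 0.114 * p[2] as f32) / 255.0)
        .collect()
}

/// Finds the box `(x0, y0, x1, y1)` (end-exclusive) around all pixels brighter than
/// `threshold`, or `None` when nothing was drawn.
fn bounding_box(
    gray: &[f32],
    width: usize,
    height: usize,
    threshold: f32,
) -> Option<(usize, usize, usize, usize)> {
    let (mut x0, mut y0, mut x1, mut y1) = (width, height, 0usize, 0usize);
    for y in 0..height {
        for x in 0..width {
            if gray[y * width + x] > threshold {
                x0 = x0.min(x);
                y0 = y0.min(y);
                x1 = x1.max(x + 1);
                y1 = y1.max(y + 1);
            }
        }
    }
    if x1 > x0 && y1 > y0 { Some((x0, y0, x1, y1)) } else { None }
}

/// Crops to `bbox` and scales to `out` x `out` by averaging the source pixels of each cell.
fn crop_and_scale(
    gray: &[f32],
    width: usize,
    bbox: (usize, usize, usize, usize),
    out: usize,
) -> Vec<f32> {
    let (x0, y0, x1, y1) = bbox;
    let (bw, bh) = (x1 - x0, y1 - y0);
    let mut result = vec![0.0f32; out * out];
    for oy in 0..out {
        for ox in 0..out {
            // Source cell covered by this output pixel; always at least one pixel wide.
            let sx0 = x0 + ox * bw / out;
            let sx1 = (x0 + (ox + 1) * bw / out).max(sx0 + 1);
            let sy0 = y0 + oy * bh / out;
            let sy1 = (y0 + (oy + 1) * bh / out).max(sy0 + 1);
            let mut sum = 0.0f32;
            for y in sy0..sy1 {
                for x in sx0..sx1 {
                    sum += gray[y * width + x];
                }
            }
            result[oy * out + ox] = sum / ((sx1 - sx0) * (sy1 - sy0)) as f32;
        }
    }
    result
}

fn main() {
    // Hypothetical 56x56 canvas with a white vertical stroke on a black background.
    let (width, height) = (56usize, 56usize);
    let mut rgba = vec![0u8; width * height * 4];
    for y in 10..46 {
        for x in 26..30 {
            let i = (y * width + x) * 4;
            rgba[i..i + 4].copy_from_slice(&[255, 255, 255, 255]);
        }
    }
    let gray = to_grayscale(&rgba);
    if let Some(bbox) = bounding_box(&gray, width, height, 0.5) {
        // 28x28 matches the model's input size.
        let input = crop_and_scale(&gray, width, bbox, 28);
        println!("bounding box: {:?}, model input values: {}", bbox, input.len());
    }
}
```

Note that this sketch stretches the cropped region to fill the square, so it does not preserve the
aspect ratio of the drawing; a real pipeline may pad or center the digit instead.
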
## Model

Layers:

1. Input image (28x28, 1ch)
2. `Conv2d`(3x3, 8ch), `BatchNorm2d`, `Gelu`
3. `Conv2d`(3x3, 16ch), `BatchNorm2d`, `Gelu`
4. `Conv2d`(3x3, 24ch), `BatchNorm2d`, `Gelu`
5. `Linear`(11616, 32), `Gelu`
6. `Linear`(32, 10)
7. Softmax output

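
The input size of the first `Linear` layer (11616) follows from the layer list if each 3x3
convolution uses stride 1 and no padding. That setting is an assumption here, chosen because it
reproduces the number above, and the helper names below are hypothetical. A minimal sketch of the
arithmetic:

```rust
/// Output side length of a square convolution (standard formula).
fn conv_output_size(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

/// Number of values after flattening the output of the last convolution.
/// Assumption: every convolution uses stride 1 and no padding.
fn flattened_features(input_side: usize, conv_channels: &[usize], kernel: usize) -> usize {
    let side = conv_channels
        .iter()
        .fold(input_side, |side, _| conv_output_size(side, kernel, 1, 0));
    let last_channels = conv_channels.last().copied().unwrap_or(1);
    last_channels * side * side
}

fn main() {
    let mut side = 28;
    for channels in [8, 16, 24] {
        side = conv_output_size(side, 3, 1, 0);
        // Side lengths shrink by 2 per layer: 26, 24, 22.
        println!("Conv2d(3x3, {}ch) -> {}x{}", channels, side, side);
    }
    // 24 channels * 22 * 22 = 11616
    println!("flattened: {}", flattened_features(28, &[8, 16, 24], 3));
}
```
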
The total number of parameters is 376,952.

The model was trained for 4 epochs, and the final test accuracy is 98.67%.

The training and hyperparameter information can be found in the
[`burn` MNIST example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).

## Comparison

The main difference between this example's approach (compiling a Rust model into Wasm) and
other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),
[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/), and
[TVM Web](https://github.com/apache/tvm/tree/main/web), is the absence of runtime code. The Rust
compiler optimizes the build and includes only the `burn` routines that are used. Of the 1,866,491
bytes in the Wasm file, 1,509,747 bytes are the model's parameters. The remaining 356,744 bytes
contain all the code (including `burn`'s `nn` components, the data deserialization library, and
math operations).

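
The size breakdown above can be checked with a few lines of arithmetic. The byte counts come from
the text; the constant and function names are hypothetical.

```rust
/// Total size of the Wasm file in bytes (from the text above).
const WASM_FILE_BYTES: u64 = 1_866_491;
/// Bytes taken by the model's parameters (from the text above).
const PARAMETER_BYTES: u64 = 1_509_747;

/// Bytes left for code once the parameters are subtracted.
fn code_bytes(total: u64, parameters: u64) -> u64 {
    total - parameters
}

/// Share of `part` in `total`, as a percentage.
fn share_percent(part: u64, total: u64) -> f64 {
    part as f64 * 100.0 / total as f64
}

fn main() {
    let code = code_bytes(WASM_FILE_BYTES, PARAMETER_BYTES);
    println!("code: {} bytes", code); // 356744
    println!("parameters: {:.1}% of the file", share_percent(PARAMETER_BYTES, WASM_FILE_BYTES));
    println!("code: {:.1}% of the file", share_percent(code, WASM_FILE_BYTES));
}
```

In other words, roughly 81% of the file is model parameters and roughly 19% is code.
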
## Future Improvements

Several enhancements are planned:

- [#1271](https://github.com/rust-ndarray/ndarray/issues/1271) -
  [WASM SIMD](https://github.com/WebAssembly/simd/blob/master/proposals/simd/SIMD.md) support in
  NDArray that can speed up computation on the CPU.

## Acknowledgements

Two online MNIST demos inspired and helped build this demo:
[MNIST Draw](https://mco-mnist-draw-rwpxka3zaa-ue.a.run.app/) by Marc (@mco-gh) and
[MNIST Web Demo](https://ufal.mff.cuni.cz/~straka/courses/npfl129/2223/demos/mnist_web.html) (no
code was copied, but they helped tremendously with the implementation approach).

## Resources

1. [Rust 🦀 and WebAssembly](https://rustwasm.github.io/docs/book/)
2. [wasm-bindgen](https://rustwasm.github.io/wasm-bindgen/)