# MNIST Inference on Web
[![Live Demo](https://img.shields.io/badge/live-demo-brightgreen)](https://burn.dev/demo)
This crate demonstrates how to run an MNIST-trained model in the browser for inference.
## Running
1. Build
```shell
./build-for-web.sh {backend}
```
The backend can be either `ndarray` or `wgpu`. Note that `wgpu` only works in browsers that support WebGPU.
2. Run the server
```shell
./run-server.sh
```
3. Open [`http://localhost:8000/`](http://localhost:8000/) in the browser.
## Design
The inference components of `burn` with the `ndarray` backend can be built with `#![no_std]`. This
makes it possible to build and run the model with the `wasm32-unknown-unknown` target without a
special system library, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) for how to
include `burn` dependencies without `std`.)
For this demo, we use the trained parameters (`model.bin`) and the model definition (`model.rs`) from the
[`burn` MNIST example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).
The inference API for JavaScript is exposed with the help of
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.
JavaScript (`index.js`) transforms the hand-drawn digit into the format that the inference API
accepts. The transformation crops the image, scales it down, and converts it to grayscale
values.
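
The demo performs this preprocessing in JavaScript. To illustrate the idea, here is a minimal sketch of the grayscale and scale-down steps, written in Rust to match the rest of the crate. It is not the code in `index.js`: the function names, the Rec. 601 luma weights, and the block-averaging downscale are all assumptions, and cropping is omitted.

```rust
/// Convert an RGBA pixel buffer to grayscale values in [0.0, 1.0].
/// Assumption: Rec. 601 luma weights; the demo may weight channels differently.
fn rgba_to_gray(rgba: &[u8]) -> Vec<f32> {
    rgba.chunks_exact(4)
        .map(|p| (0.299 * p[0] as f32 + 0.587 * p[1] as f32 + 0.114 * p[2] as f32) / 255.0)
        .collect()
}

/// Downscale a square grayscale image by averaging `factor` x `factor` blocks.
/// Assumption: `side` is an exact multiple of `factor`.
fn downscale(gray: &[f32], side: usize, factor: usize) -> Vec<f32> {
    let out_side = side / factor;
    let mut out = vec![0.0f32; out_side * out_side];
    for oy in 0..out_side {
        for ox in 0..out_side {
            let mut sum = 0.0f32;
            for dy in 0..factor {
                for dx in 0..factor {
                    sum += gray[(oy * factor + dy) * side + (ox * factor + dx)];
                }
            }
            out[oy * out_side + ox] = sum / (factor * factor) as f32;
        }
    }
    out
}

fn main() {
    // Hypothetical 56x56 all-white RGBA canvas, scaled down to the 28x28 model input.
    let side = 56;
    let rgba = vec![255u8; side * side * 4];
    let gray = rgba_to_gray(&rgba);
    let small = downscale(&gray, side, 2);
    println!("{} input values", small.len());
}
```

Averaging blocks is only one way to scale down; a canvas-based implementation in the browser would typically rely on the browser's own image scaling instead.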
## Model
Layers:
1. Input Image (28x28, 1ch)
2. `Conv2d`(3x3, 8ch), `BatchNorm2d`, `Gelu`
3. `Conv2d`(3x3, 16ch), `BatchNorm2d`, `Gelu`
4. `Conv2d`(3x3, 24ch), `BatchNorm2d`, `Gelu`
5. `Linear`(11616, 32), `Gelu`
6. `Linear`(32, 10)
7. Softmax Output
The total number of parameters is 376,952.
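
The input size of the first `Linear` layer (11616) follows from the layer list. The sketch below reproduces that arithmetic, assuming each 3x3 convolution uses stride 1 and no padding (an assumption that is consistent with the 11616 figure) and that the `Linear` layer has a bias. The helper `conv_out` is a name made up for this example.

```rust
/// Output side length of a square convolution without dilation.
fn conv_out(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

fn main() {
    // Assumed: stride 1 and no padding for all three 3x3 convolutions.
    let mut side = 28;
    for _ in 0..3 {
        side = conv_out(side, 3, 1, 0); // 28 -> 26 -> 24 -> 22
    }

    let channels = 24; // output channels of the last Conv2d
    let flattened = channels * side * side;

    // Weights plus biases of Linear(11616, 32), assuming a bias term.
    let linear_params = flattened * 32 + 32;

    // prints "side = 22, flattened = 11616, linear params = 371744"
    println!("side = {side}, flattened = {flattened}, linear params = {linear_params}");
}
```

Under these assumptions, the first `Linear` layer alone accounts for 371,744 of the parameters, which is why it dominates the model size.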
The model is trained for 4 epochs and the final test accuracy is 98.67%.
The training and hyperparameter information can be found in the
[`burn` MNIST example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).
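
The last entry in the layer list is a softmax over ten values, one per digit. As a rough illustration of what that step does, the sketch below turns ten raw scores into probabilities and picks the most likely digit. It uses only the standard library and is not `burn`'s implementation; the scores and function names are made up for the example.

```rust
/// Numerically stable softmax: subtract the maximum before exponentiating.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Index of the largest probability, i.e. the predicted digit.
fn argmax(probs: &[f32]) -> usize {
    let mut best = 0;
    for (i, &p) in probs.iter().enumerate() {
        if p > probs[best] {
            best = i;
        }
    }
    best
}

fn main() {
    // Hypothetical raw scores for the digits 0 through 9.
    let logits: [f32; 10] = [0.1, 0.3, 2.5, 0.0, -1.0, 0.2, 0.4, 4.0, 0.5, 0.1];
    let probs = softmax(&logits);
    // prints "predicted digit: 7"
    println!("predicted digit: {}", argmax(&probs));
}
```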
## Comparison
The main difference between this example's approach (compiling a Rust model into Wasm) and
other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),
[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/) and
[TVM Web](https://github.com/apache/tvm/tree/main/web), is the absence of runtime code. The Rust
compiler optimizes and includes only the `burn` routines that are used. Of the 1,866,491 bytes in
the Wasm file, 1,509,747 bytes are the model's parameters. The remaining 356,744 bytes contain all
the code (including `burn`'s `nn` components, the data deserialization library, and math operations).
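
The size breakdown above is plain arithmetic and can be checked with a few lines of Rust. The byte counts are the ones quoted in the text; the function names are hypothetical.

```rust
/// Bytes left for code once the parameters are subtracted from the file size.
fn code_bytes(total: u64, params: u64) -> u64 {
    total - params
}

/// Share of `part` in `total`, as a percentage.
fn share_percent(part: u64, total: u64) -> f64 {
    part as f64 / total as f64 * 100.0
}

fn main() {
    let total = 1_866_491; // size of the Wasm file in bytes
    let params = 1_509_747; // bytes taken by the model parameters

    let code = code_bytes(total, params);
    println!("code: {code} bytes");
    println!("parameters: {:.1}% of the file", share_percent(params, total));
    println!("code: {:.1}% of the file", share_percent(code, total));
}
```

In other words, roughly four fifths of the download is model weights, and about one fifth is executable code.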
## Future Improvements
Planned enhancements:
- [#1271](https://github.com/rust-ndarray/ndarray/issues/1271) -
[WASM SIMD](https://github.com/WebAssembly/simd/blob/master/proposals/simd/SIMD.md) support in
NDArray that can speed up computation on CPU.
## Acknowledgements
Two online MNIST demos inspired this demo and helped in building it:
[MNIST Draw](https://mco-mnist-draw-rwpxka3zaa-ue.a.run.app/) by Marc (@mco-gh) and
[MNIST Web Demo](https://ufal.mff.cuni.cz/~straka/courses/npfl129/2223/demos/mnist_web.html). No
code was copied, but both helped tremendously with the implementation approach.
## Resources
1. [Rust 🦀 and WebAssembly](https://rustwasm.github.io/docs/book/)
2. [wasm-bindgen](https://rustwasm.github.io/wasm-bindgen/)