MNIST Inference on Web
This crate demonstrates how to run an MNIST-trained model in the browser for inference.
Running
- Build:
  ./build-for-web.sh {backend}
  The backend can be either ndarray or wgpu. Note that wgpu only works in browsers that support WebGPU.
- Run the server:
  ./run-server.sh
- Open http://localhost:8000/ in the browser.
Design
The inference components of burn with the ndarray backend can be built with #![no_std]. This makes it possible to build and run the model with the wasm32-unknown-unknown target without a special system library, such as WASI. (See Cargo.toml for how to include burn dependencies without std.)
For this demo, we use the trained parameters (model.bin) and model (model.rs) from the burn MNIST example.
The inference API for JavaScript is exposed with the help of wasm-bindgen's library and tools.
JavaScript (index.js) is used to transform hand-drawn digits into a format that the inference API accepts. The transformation includes cropping the image, scaling it down, and converting it to grayscale values.
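The demo performs this step in JavaScript (index.js), whose exact logic is not reproduced here. As a rough illustration of the same three operations, the Rust sketch below takes an RGBA pixel buffer, converts it to grayscale, crops it to the bounding box of the drawn strokes, and resamples the crop to 28x28 by block averaging. All function names are hypothetical, and the BT.601 luma weights and the assumption of light strokes on a black background (the MNIST convention) are illustrative choices, not taken from index.js.

```rust
/// Hypothetical helper: RGBA bytes (4 per pixel) -> grayscale in [0, 1].
/// Uses BT.601 luma weights; the alpha channel is ignored.
fn to_grayscale(rgba: &[u8]) -> Vec<f32> {
    rgba.chunks_exact(4)
        .map(|p| (0.299 * p[0] as f32 + 0.587 * p[1] as f32 + 0.114 * p[2] as f32) / 255.0)
        .collect()
}

/// Smallest rectangle (x0, y0, x1, y1), upper bounds exclusive, containing every
/// pixel brighter than `threshold`. Returns None for a blank image.
fn bounding_box(img: &[f32], w: usize, h: usize, threshold: f32) -> Option<(usize, usize, usize, usize)> {
    let (mut x0, mut y0, mut x1, mut y1) = (w, h, 0, 0);
    for y in 0..h {
        for x in 0..w {
            if img[y * w + x] > threshold {
                x0 = x0.min(x);
                y0 = y0.min(y);
                x1 = x1.max(x + 1);
                y1 = y1.max(y + 1);
            }
        }
    }
    // x1 stays 0 only if no pixel passed the threshold.
    if x1 == 0 { None } else { Some((x0, y0, x1, y1)) }
}

/// Copy the rectangle [x0, x1) x [y0, y1) out of a row-major image of width `w`.
fn crop(img: &[f32], w: usize, (x0, y0, x1, y1): (usize, usize, usize, usize)) -> Vec<f32> {
    let mut out = Vec::with_capacity((x1 - x0) * (y1 - y0));
    for y in y0..y1 {
        out.extend_from_slice(&img[y * w + x0..y * w + x1]);
    }
    out
}

/// Resample to out_w x out_h: each output pixel is the mean of the source block
/// it covers (at least one source pixel, so small crops fall back to nearest).
fn resize(src: &[f32], w: usize, h: usize, out_w: usize, out_h: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; out_w * out_h];
    for oy in 0..out_h {
        let y0 = oy * h / out_h;
        let y1 = ((oy + 1) * h / out_h).max(y0 + 1);
        for ox in 0..out_w {
            let x0 = ox * w / out_w;
            let x1 = ((ox + 1) * w / out_w).max(x0 + 1);
            let mut sum = 0.0;
            for y in y0..y1 {
                for x in x0..x1 {
                    sum += src[y * w + x];
                }
            }
            out[oy * out_w + ox] = sum / ((y1 - y0) * (x1 - x0)) as f32;
        }
    }
    out
}

/// Hypothetical pipeline: grayscale -> crop to strokes -> 28x28 (784 values).
fn preprocess(rgba: &[u8], w: usize, h: usize) -> Vec<f32> {
    assert_eq!(rgba.len(), w * h * 4);
    let gray = to_grayscale(rgba);
    match bounding_box(&gray, w, h, 0.0) {
        None => vec![0.0; 28 * 28], // blank canvas
        Some(bb) => {
            let (cw, ch) = (bb.2 - bb.0, bb.3 - bb.1);
            resize(&crop(&gray, w, bb), cw, ch, 28, 28)
        }
    }
}
```

Stretching the tight bounding box straight to 28x28 distorts narrow digits such as 1; a more faithful pipeline would pad the crop to a square (and center it) before scaling.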
Model
Layers:
- Input Image (28x28, 1ch)
- Conv2d (3x3, 8ch), BatchNorm2d, Gelu
- Conv2d (3x3, 16ch), BatchNorm2d, Gelu
- Conv2d (3x3, 24ch), BatchNorm2d, Gelu
- Linear (11616, 32), Gelu
- Linear (32, 10)
- Softmax Output
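The final softmax turns the ten logits from the last Linear layer into probabilities, and the predicted digit is the index of the largest one. The demo computes this with burn; the plain-Rust sketch below is only an illustration, and `softmax` and `predict_digit` are hypothetical helpers. It subtracts the maximum logit before exponentiating, the usual trick to avoid overflow.

```rust
/// Numerically stable softmax: shifting by the max logit keeps exp() from overflowing
/// and does not change the result.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Index of the highest probability, i.e. the predicted digit (0-9).
/// Panics on an empty slice or NaN values.
fn predict_digit(probs: &[f32]) -> usize {
    probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}
```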
The total number of parameters is 376,952.
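The stated total can be reproduced from the layer list. The sketch below rests on assumptions of ours, chosen because they reproduce 376,952 exactly; check model.rs if the configuration changes: the convolutions are stride 1 with no padding (so 28 → 26 → 24 → 22, giving 24 × 22 × 22 = 11,616 inputs to the first Linear layer), the convolutions have biases while the two Linear layers do not, and each BatchNorm2d channel is counted with four values (scale, shift, running mean, running variance). All function names are hypothetical.

```rust
/// Output side length of a stride-1 convolution without padding.
fn conv_out(size: usize, kernel: usize) -> usize {
    size - kernel + 1
}

/// k x k weights per (input, output) channel pair, plus one bias per output channel.
fn conv2d_params(in_ch: usize, out_ch: usize, k: usize) -> usize {
    in_ch * out_ch * k * k + out_ch
}

/// Assumption: gamma, beta, running mean and running variance per channel.
fn batch_norm_params(ch: usize) -> usize {
    4 * ch
}

/// Assumption: Linear layers without bias.
fn linear_params(d_in: usize, d_out: usize) -> usize {
    d_in * d_out
}

/// Returns (flattened feature size, total parameter count).
fn total_params() -> (usize, usize) {
    let channels = [1, 8, 16, 24];
    let mut side = 28;
    let mut total = 0;
    // One Conv2d + BatchNorm2d block per consecutive channel pair.
    for pair in channels.windows(2) {
        total += conv2d_params(pair[0], pair[1], 3) + batch_norm_params(pair[1]);
        side = conv_out(side, 3);
    }
    let flat = channels[3] * side * side; // 24 * 22 * 22
    total += linear_params(flat, 32) + linear_params(32, 10);
    (flat, total)
}
```

Almost all of the parameters sit in the first Linear layer (11,616 × 32 = 371,712), which is why the convolutional part barely affects the file size.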
The model was trained for 4 epochs, and the final test accuracy is 98.67%.
The training and hyperparameter information can be found in the burn MNIST example.
Comparison
The main difference between this example's approach (compiling a Rust model into Wasm) and other popular tools, such as TensorFlow.js, ONNX Runtime JS, and TVM Web, is the absence of runtime code. The Rust compiler optimizes the code and includes only the burn routines that are actually used. Of the Wasm file's 1,866,491 bytes, 1,509,747 bytes are the model's parameters. The remaining 356,744 bytes contain all the code (including burn's nn components, the data deserialization library, and math operations).
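These byte figures are consistent with the Model section's parameter count if each parameter is stored as a 4-byte f32, which is our assumption since the serialization format is not described here. The quick check below shows that 376,952 × 4 = 1,507,808 bytes, leaving 1,939 bytes of the parameter portion for what we assume is serialization metadata, and that code makes up about 19% of the Wasm file. The function name is hypothetical.

```rust
/// Returns (code bytes, raw f32 parameter bytes, leftover bytes, code share in %).
fn size_breakdown() -> (u64, u64, u64, f64) {
    let wasm_bytes: u64 = 1_866_491;  // whole Wasm file, as quoted above
    let param_bytes: u64 = 1_509_747; // portion holding the model parameters
    let num_params: u64 = 376_952;    // from the Model section
    let code = wasm_bytes - param_bytes;      // 356,744 bytes
    let raw_f32 = num_params * 4;             // assumes 4 bytes per f32 parameter
    let overhead = param_bytes - raw_f32;     // assumed serialization metadata
    let code_share = 100.0 * code as f64 / wasm_bytes as f64;
    (code, raw_f32, overhead, code_share)
}
```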
Future Improvements
Several enhancements are planned:
- #202 - Saving the model's parameters in half precision and loading them back in full precision. This would roughly halve the space the parameters take up in the wasm file.
- #243 - A new WebGPU backend would allow computation on the GPU in the browser.
- #1271 - WASM SIMD support in NDArray, which can speed up computation on the CPU.
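To put #202 in rough numbers: if the 376,952 parameters were stored as 2-byte f16 values, and assuming the code size and the 1,939 non-parameter bytes in the parameter portion stayed the same, only the parameter portion would shrink, so the whole file would get about 40% smaller. This is a back-of-the-envelope sketch with a hypothetical function name, not a measurement.

```rust
/// Hypothetical projection for half-precision parameters (issue #202).
/// Returns (projected Wasm size in bytes, projected size as % of the current file).
fn projected_f16_size() -> (u64, f64) {
    let wasm_bytes: u64 = 1_866_491;  // current Wasm file size
    let param_bytes: u64 = 1_509_747; // of which: model parameters
    let num_params: u64 = 376_952;
    let code = wasm_bytes - param_bytes;          // 356,744 bytes, assumed unchanged
    let overhead = param_bytes - num_params * 4;  // 1,939 bytes, assumed unchanged
    let projected = code + overhead + num_params * 2; // 2 bytes per f16 parameter
    (projected, 100.0 * projected as f64 / wasm_bytes as f64)
}
```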
Acknowledgements
Two online MNIST demos inspired and helped build this demo: MNIST Draw by Marc (@mco-gh) and MNIST Web Demo (no code was copied, but they helped tremendously with the implementation approach).