Add image classification web demo with WebGPU, CPU backends (#840)
75
Cargo.toml
|
@ -4,31 +4,32 @@
|
|||
resolver = "2"
|
||||
|
||||
members = [
|
||||
"burn",
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-core",
|
||||
"burn-dataset",
|
||||
"burn-derive",
|
||||
"burn-import",
|
||||
"burn-import/onnx-tests",
|
||||
"burn-ndarray",
|
||||
"burn-no-std-tests",
|
||||
"burn-tch",
|
||||
"burn-wgpu",
|
||||
"burn-candle",
|
||||
"burn-tensor-testgen",
|
||||
"burn-tensor",
|
||||
"burn-train",
|
||||
"xtask",
|
||||
"examples/*",
|
||||
"backend-comparison",
|
||||
"burn",
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-core",
|
||||
"burn-dataset",
|
||||
"burn-derive",
|
||||
"burn-import",
|
||||
"burn-import/onnx-tests",
|
||||
"burn-ndarray",
|
||||
"burn-no-std-tests",
|
||||
"burn-tch",
|
||||
"burn-wgpu",
|
||||
"burn-candle",
|
||||
"burn-tensor-testgen",
|
||||
"burn-tensor",
|
||||
"burn-train",
|
||||
"xtask",
|
||||
"examples/*",
|
||||
"backend-comparison",
|
||||
]
|
||||
|
||||
exclude = ["examples/notebook"]
|
||||
|
||||
[workspace.dependencies]
|
||||
async-trait = "0.1.73"
|
||||
bytemuck = "1.13"
|
||||
const-random = "0.1.15"
|
||||
csv = "1.2.2"
|
||||
|
@ -37,11 +38,12 @@ dirs = "5.0.1"
|
|||
fake = "2.6.1"
|
||||
flate2 = "1.0.26"
|
||||
float-cmp = "0.9.0"
|
||||
getrandom = { version = "0.2.10", default-features = false }
|
||||
gix-tempfile = { version = "8.0.0", features = ["signals"] }
|
||||
hashbrown = "0.14.0"
|
||||
indicatif = "0.17.5"
|
||||
libm = "0.2.7"
|
||||
log = "0.4.19"
|
||||
log = { default-features = false, version = "0.4.19" }
|
||||
pretty_assertions = "1.3"
|
||||
proc-macro2 = "1.0.60"
|
||||
protobuf-codegen = "3.2"
|
||||
|
@ -55,15 +57,17 @@ rusqlite = { version = "0.29" }
|
|||
sanitize-filename = "0.5.0"
|
||||
serde_rusqlite = "0.33.1"
|
||||
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
|
||||
strum = "0.24"
|
||||
strum_macros = "0.24"
|
||||
strum = "0.25.0"
|
||||
strum_macros = "0.25.2"
|
||||
syn = { version = "2.0", features = ["full", "extra-traits"] }
|
||||
tempfile = "3.6.0"
|
||||
thiserror = "1.0.40"
|
||||
tracing-subscriber = "0.3.17"
|
||||
tracing-core = "0.1.31"
|
||||
tracing-appender = "0.2.2"
|
||||
async-trait = "0.1.73"
|
||||
tracing-core = "0.1.31"
|
||||
tracing-subscriber = "0.3.17"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
wasm-logger = "0.2.0"
|
||||
|
||||
# WGPU stuff
|
||||
futures-intrusive = "0.5"
|
||||
|
@ -75,26 +79,27 @@ wgpu = "0.17.0"
|
|||
# The following packages disable the "std" feature for no_std compatibility
|
||||
#
|
||||
bincode = { version = "2.0.0-rc.3", features = [
|
||||
"alloc",
|
||||
"serde",
|
||||
"alloc",
|
||||
"serde",
|
||||
], default-features = false }
|
||||
derive-new = { version = "0.5.9", default-features = false }
|
||||
|
||||
half = { version = "2.3.1", features = [
|
||||
"alloc",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"alloc",
|
||||
"num-traits",
|
||||
"serde",
|
||||
], default-features = false }
|
||||
ndarray = { version = "0.15.6", default-features = false }
|
||||
num-traits = { version = "0.2.15", default-features = false, features = [
|
||||
"libm",
|
||||
"libm",
|
||||
] } # libm is for no_std
|
||||
rand = { version = "0.8.5", default-features = false, features = [
|
||||
"std_rng",
|
||||
"std_rng",
|
||||
] } # std_rng is for no_std
|
||||
rand_distr = { version = "0.4.3", default-features = false }
|
||||
serde = { version = "1.0.164", default-features = false, features = [
|
||||
"derive",
|
||||
"alloc",
|
||||
"derive",
|
||||
"alloc",
|
||||
] } # alloc is for no_std, derive is needed
|
||||
serde_json = { version = "1.0.96", default-features = false }
|
||||
uuid = { version = "1.3.4", default-features = false }
|
||||
|
|
11
_typos.toml
|
@ -1,9 +1,8 @@
|
|||
[default]
|
||||
extend-ignore-identifiers-re = [
|
||||
"ratatui",
|
||||
"NdArray*",
|
||||
"ND"
|
||||
]
|
||||
extend-ignore-identifiers-re = ["ratatui", "NdArray*", "ND"]
|
||||
|
||||
[files]
|
||||
extend-exclude = ["assets/ModuleSerialization.xml"]
|
||||
extend-exclude = [
|
||||
"assets/ModuleSerialization.xml",
|
||||
"examples/image-classification-web/src/model/label.txt",
|
||||
]
|
||||
|
|
|
@ -14,8 +14,8 @@ version = "0.10.0"
|
|||
|
||||
[dependencies]
|
||||
derive-new = { workspace = true }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.10.0" }
|
||||
half = { workspace = true, features = ["std"] }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = false }
|
||||
half = { workspace = true }
|
||||
# candle-core = { version = "0.1.2" }
|
||||
candle-core = { git = "https://github.com/huggingface/candle", rev = "237323c" }
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ std = ["rand/std"]
|
|||
|
||||
[target.'cfg(target_family = "wasm")'.dependencies]
|
||||
async-trait = { workspace = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
getrandom = { workspace = true, features = ["js"] }
|
||||
|
||||
[dependencies]
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
|
|
|
@ -13,17 +13,16 @@ version = "0.10.0"
|
|||
[features]
|
||||
default = ["std", "dataset-minimal"]
|
||||
std = [
|
||||
"burn-common/std",
|
||||
"burn-tensor/std",
|
||||
"flate2",
|
||||
"log",
|
||||
"rand/std",
|
||||
"rmp-serde",
|
||||
"serde/std",
|
||||
"serde_json/std",
|
||||
"bincode/std",
|
||||
"half/std",
|
||||
"derive-new/std",
|
||||
"burn-common/std",
|
||||
"burn-tensor/std",
|
||||
"flate2",
|
||||
"log",
|
||||
"rand/std",
|
||||
"rmp-serde",
|
||||
"serde/std",
|
||||
"serde_json/std",
|
||||
"bincode/std",
|
||||
"half/std",
|
||||
]
|
||||
dataset = ["burn-dataset/default"]
|
||||
dataset-minimal = ["burn-dataset"]
|
||||
|
@ -35,10 +34,18 @@ autodiff = ["burn-autodiff"]
|
|||
|
||||
ndarray = ["__ndarray", "burn-ndarray/default"]
|
||||
ndarray-no-std = ["__ndarray", "burn-ndarray"]
|
||||
ndarray-blas-accelerate = ["__ndarray", "ndarray", "burn-ndarray/blas-accelerate"]
|
||||
ndarray-blas-accelerate = [
|
||||
"__ndarray",
|
||||
"ndarray",
|
||||
"burn-ndarray/blas-accelerate",
|
||||
]
|
||||
ndarray-blas-netlib = ["__ndarray", "ndarray", "burn-ndarray/blas-netlib"]
|
||||
ndarray-blas-openblas = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas"]
|
||||
ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas-system"]
|
||||
ndarray-blas-openblas-system = [
|
||||
"__ndarray",
|
||||
"ndarray",
|
||||
"burn-ndarray/blas-openblas-system",
|
||||
]
|
||||
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.
|
||||
|
||||
wgpu = ["burn-wgpu/default"]
|
||||
|
@ -48,8 +55,8 @@ tch = ["burn-tch"]
|
|||
# Serialization formats
|
||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||
|
||||
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
|
||||
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
|
||||
|
||||
[dependencies]
|
||||
|
||||
|
@ -66,7 +73,7 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true
|
|||
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
|
||||
|
||||
derive-new = { workspace = true, default-features = false }
|
||||
derive-new = { workspace = true }
|
||||
libm = { workspace = true }
|
||||
log = { workspace = true, optional = true }
|
||||
rand = { workspace = true, features = ["std_rng"] } # Default enables std
|
||||
|
@ -88,7 +95,7 @@ serde_json = { workspace = true, features = ["alloc"] } #Default enables std
|
|||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
burn-dataset = { path = "../burn-dataset", version = "0.10.0", features = [
|
||||
"fake",
|
||||
"fake",
|
||||
] }
|
||||
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", default-features = false }
|
||||
|
|
|
@ -17,8 +17,7 @@ default = ["onnx"]
|
|||
onnx = []
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../burn", version = "0.10.0" }
|
||||
burn-common = {path = "../burn-common", version = "0.10.0" }
|
||||
burn = {path = "../burn", version = "0.10.0"}
|
||||
burn-ndarray = {path = "../burn-ndarray", version = "0.10.0" }
|
||||
|
||||
bytemuck = {workspace = true}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
[package]
|
||||
authors = ["Dilshod Tadjibaev (@antimora)"]
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
name = "image-classification-web"
|
||||
publish = false
|
||||
version = "0.10.0"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../burn", default-features = false, features = [
|
||||
"ndarray-no-std",
|
||||
"wgpu",
|
||||
] }
|
||||
|
||||
burn-candle = { path = "../../burn-candle", version = "0.10.0", default-features = false }
|
||||
|
||||
js-sys = "0.3.64"
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde-wasm-bindgen = "0.6.0"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
wasm-logger = "0.2.0"
|
||||
wasm-timer = "0.2.5"
|
||||
|
||||
[build-dependencies]
|
||||
# Used to generate code from ONNX model
|
||||
burn-import = { path = "../../burn-import" }
|
|
@ -0,0 +1,52 @@
|
|||
# NOTICES AND INFORMATION
|
||||
|
||||
This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.
|
||||
|
||||
## Sample Images
|
||||
|
||||
Image Title: Domestic cat, a ten month old female.
|
||||
Author: Von.grzanka
|
||||
Source: https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg
|
||||
License: https://creativecommons.org/licenses/by-sa/3.0/
|
||||
|
||||
Image Title: The George Washington Bridge over the Hudson River leading to New York City as seen from Fort Lee, New Jersey.
|
||||
Author: John O'Connell
|
||||
Source: https://commons.wikimedia.org/wiki/File:George_Washington_Bridge_from_New_Jersey-edit.jpg
|
||||
License: https://creativecommons.org/licenses/by/2.0/deed.en
|
||||
|
||||
Image Title: Coyote from Yosemite National Park, California in snow
|
||||
Author: Yathin S Krishnappa
|
||||
Source: https://commons.wikimedia.org/wiki/File:2009-Coyote-Yosemite.jpg
|
||||
License: https://creativecommons.org/licenses/by-sa/3.0/deed.en
|
||||
|
||||
Image Title: Table lamp with a lampshade illuminated by sunlight.
|
||||
Author: LoMit
|
||||
Source: https://commons.wikimedia.org/wiki/File:Lamp_with_a_lampshade_illuminated_by_sunlight.jpg
|
||||
License: https://creativecommons.org/licenses/by-sa/4.0/deed.en
|
||||
|
||||
Image Title: White Pelican Pelecanus onocrotalus at Walvis Bay, Namibia
|
||||
Author: Rui Ornelas
|
||||
Source: https://commons.wikimedia.org/wiki/File:Pelikan_Walvis_Bay.jpg
|
||||
License: https://creativecommons.org/licenses/by/2.0/deed.en
|
||||
|
||||
Image Title: Photo of a traditional torch to be posted at gates
|
||||
Author: Faizul Latif Chowdhury
|
||||
Source: https://commons.wikimedia.org/wiki/File:Torch_traditional.jpg
|
||||
License: https://creativecommons.org/licenses/by-sa/3.0/deed.en
|
||||
|
||||
Image Title: American Flamingo Phoenicopterus ruber at Gotomeer, Riscado, Bonaire
|
||||
Author: Paul Asman and Jill Lenoble
|
||||
Source: https://commons.wikimedia.org/wiki/File:Phoenicopterus_ruber_Bonaire_2.jpg
|
||||
License: https://creativecommons.org/licenses/by/2.0/deed.en
|
||||
|
||||
## ONNX Model
|
||||
|
||||
SqueezeNet 1.1 model is licensed under Apache License 2.0. The model is downloaded from the [ONNX model zoo](https://github.com/onnx/models/tree/main).
|
||||
|
||||
Source: https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx
|
||||
License: Apache License 2.0
|
||||
License URL: https://github.com/onnx/models/blob/main/LICENSE
|
||||
|
||||
## ONNX Labels
|
||||
|
||||
The labels for the SqueezeNet 1.1 model are licensed under Apache License 2.0. The labels are downloaded from the [ONNX model zoo](https://github.com/onnx/models/blob/main/vision/classification/synset.txt)
|
|
@ -0,0 +1,74 @@
|
|||
# Image Classification Web Demo Using Burn and WebAssembly
|
||||
|
||||
## Overview
|
||||
|
||||
This demo showcases how to execute an image classification task in a web browser using a model
|
||||
converted to Rust code. The project utilizes the Burn deep learning framework, WebGPU and
|
||||
WebAssembly. Specifically, it demonstrates:
|
||||
|
||||
1. Converting an ONNX (Open Neural Networks Exchange) model into Rust code compatible with the Burn
|
||||
framework.
|
||||
2. Executing the model within a web browser using WebGPU via the `burn-wgpu` backend and WebAssembly
|
||||
through the `burn-ndarray` and `burn-candle` backends.
|
||||
|
||||
## Running the Demo
|
||||
|
||||
### Step 1: Build the WebAssembly Binary and Other Assets
|
||||
|
||||
To compile the Rust code into WebAssembly and build other essential files, execute the following
|
||||
script:
|
||||
|
||||
```bash
|
||||
./build-for-web.sh
|
||||
```
|
||||
|
||||
### Step 2: Launch the Web Server
|
||||
|
||||
Run the following command to initiate a web server on your local machine:
|
||||
|
||||
```bash
|
||||
./run-server.sh
|
||||
```
|
||||
|
||||
### Step 3: Access the Web Demo
|
||||
|
||||
Open your web browser and navigate to:
|
||||
|
||||
```plaintext
|
||||
http://localhost:8000
|
||||
```
|
||||
|
||||
## Backend Compatibility
|
||||
|
||||
As of now, the WebGPU backend is compatible only with Chrome browsers running on macOS and Windows.
|
||||
The application will dynamically detect if WebGPU support is available and proceed accordingly.
|
||||
|
||||
## SIMD Support
|
||||
|
||||
The build targets two sets of binaries, one with SIMD support and one without. The web application
|
||||
dynamically detects if SIMD support is available and downloads the appropriate binary.
|
||||
|
||||
## Model Information
|
||||
|
||||
The image classification task is achieved using the SqueezeNet model, a compact Convolutional Neural
|
||||
Network (CNN). It is trained on the ImageNet dataset and can classify images into 1,000 distinct
|
||||
categories. The included ONNX model is sourced from the
|
||||
[ONNX Model Zoo](https://github.com/onnx/models/tree/main/vision/classification/squeezenet). For
|
||||
further details about the model's architecture and performance, you can refer to the
|
||||
[original paper](https://arxiv.org/abs/1602.07360).
|
||||
|
||||
## Credits
|
||||
|
||||
This demo was inspired by the ONNX Runtime web demo featuring the
|
||||
[SqueezeNet model trained on ImageNet](https://microsoft.github.io/onnxruntime-web-demo/#/squeezenet).
|
||||
|
||||
The complete list of credits/attribution can be found in the [NOTICES](NOTICES.md) file.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Fall back to WebGL if WebGPU is not supported by the browser. See
|
||||
[wgpu's WebGL support](https://github.com/gfx-rs/wgpu/wiki/Running-on-the-Web-with-WebGPU-and-WebGL)
|
||||
|
||||
- [ ] Enable SIMD support for Safari browsers after Release 179.
|
||||
|
||||
- [ ] Add image paste functionality to allow users to paste an image from the clipboard.
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env bash
# Build the WebAssembly binaries (SIMD and non-SIMD) plus supporting assets
# for the image-classification web demo.

# Abort on the first failing command so a broken build is not silently ignored.
set -e

# Add wasm32 target for compiler.
rustup target add wasm32-unknown-unknown

# Install wasm-pack on demand, then continue with the build.
# (The original script called `exit` right after `cargo install wasm-pack`,
# which forced a second invocation on a fresh machine.)
if ! command -v wasm-pack &>/dev/null; then
    echo "wasm-pack could not be found. Installing ..."
    cargo install wasm-pack
fi

mkdir -p pkg

echo "Building SIMD version of wasm for web ..."
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 -Ctarget-feature=+simd128 --cfg web_sys_unstable_apis"
wasm-pack build --dev --out-dir pkg/simd --target web --no-typescript

echo "Building Non-SIMD version of wasm for web ..."
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis"
wasm-pack build --dev --out-dir pkg/no_simd --target web --no-typescript
|
|
@ -0,0 +1,75 @@
|
|||
/// This build script generates the model code from the ONNX file and the labels from the text file.
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::Path;
|
||||
|
||||
use burn_import::burn::graph::RecordType;
|
||||
use burn_import::onnx::ModelGen;
|
||||
|
||||
const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
|
||||
const LABEL_DEST_FILE: &str = "model/label.rs";
|
||||
const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx";
|
||||
const OUT_DIR: &str = "model/";
|
||||
|
||||
fn main() {
|
||||
// Re-run the build script if model files change.
|
||||
println!("cargo:rerun-if-changed=src/model");
|
||||
|
||||
// Check if half precision is enabled.
|
||||
let half_precision = cfg!(feature = "half_precision");
|
||||
|
||||
// Generate the model code from the ONNX file.
|
||||
ModelGen::new()
|
||||
.input(INPUT_ONNX_FILE)
|
||||
.out_dir(OUT_DIR)
|
||||
.record_type(RecordType::Bincode)
|
||||
.embed_states(true)
|
||||
.half_precision(half_precision)
|
||||
.run_from_script();
|
||||
|
||||
// Generate the labels from the synset.txt file.
|
||||
generate_labels_from_txt_file().unwrap();
|
||||
}
|
||||
|
||||
/// Read labels from synset.txt and store them in a vector of strings in a Rust file.
|
||||
fn generate_labels_from_txt_file() -> std::io::Result<()> {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE);
|
||||
let mut f = File::create(&dest_path)?;
|
||||
|
||||
let file = File::open(LABEL_SOURCE_FILE)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
writeln!(f, "pub static LABELS: &[&str] = &[")?;
|
||||
for line in reader.lines() {
|
||||
writeln!(
|
||||
f,
|
||||
" \"{}\",",
|
||||
extract_simple_label(line.unwrap()).unwrap()
|
||||
)?;
|
||||
}
|
||||
writeln!(f, "];")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract the simple label from the full label.
///
/// A full label looks like:
///   "n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea"
/// and the simple label is the first human-readable name: "indigo bunting".
fn extract_simple_label(input: String) -> Option<String> {
    // Everything after the first space; the leading token is the id code.
    // A line without any space yields an empty label, matching the original
    // split/skip/join behaviour.
    let tail = input.split_once(' ').map_or("", |(_, rest)| rest);

    // Keep only the text before the first comma, if any.
    let simple = tail.split(',').next().unwrap_or(tail);

    Some(simple.to_string())
}
|
|
@ -0,0 +1,11 @@
|
|||
"""Serve the current directory over HTTP so the demo can be opened in a browser."""
import http.server
import socketserver

# Port the demo is served on; index.html will be reachable at this address.
PORT = 8000

with socketserver.TCPServer(("", PORT), http.server.SimpleHTTPRequestHandler) as httpd:
    print(f"Running local python HTTP server on port {PORT} ...")
    print(f"Serving HTTP on http://localhost:{PORT}/ ...")
    httpd.serve_forever()
|
|
@ -0,0 +1,57 @@
|
|||
.container {
|
||||
width: 100%;
|
||||
max-width: 800px;
|
||||
margin: auto;
|
||||
}
|
||||
|
||||
.selections {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.select-box {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.file-input-box {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.actions {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
#chart {
|
||||
border: 1px solid #aaa;
|
||||
/* width: 600px; */
|
||||
/* height: 300px; */
|
||||
}
|
||||
|
||||
#imageCanvas {
|
||||
border: 1px solid #aaa;
|
||||
}
|
||||
|
||||
/* Wrapping container for the three boxes */
|
||||
.row-container {
|
||||
display: flex;
|
||||
align-items: center; /* Vertically center the content */
|
||||
justify-content: space-between; /* Distributes space between the items */
|
||||
flex-wrap: wrap; /* Allows the flex items to wrap */
|
||||
}
|
||||
|
||||
.canvas-box,
|
||||
.chart-box {
|
||||
flex: 1; /* Takes up equal width */
|
||||
}
|
||||
|
||||
#time {
|
||||
text-align: center;
|
||||
white-space: nowrap;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
#time > span {
|
||||
font-weight: bold;
|
||||
}
|
|
@ -0,0 +1,243 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Image Classification</title>
|
||||
|
||||
<script
|
||||
src="https://cdn.jsdelivr.net/npm/wasm-feature-detect@1.5.1/dist/umd/index.min.js"
|
||||
integrity="sha256-9+AQR2dApXE+f/D998vy0RATN/o4++mqVjAZ3lo432g="
|
||||
crossorigin="anonymous"
|
||||
></script>
|
||||
|
||||
<script
|
||||
src="https://cdn.jsdelivr.net/npm/chart.js@4.2.1/dist/chart.umd.min.js"
|
||||
integrity="sha256-tgiW1vJqfIKxE0F2uVvsXbgUlTyrhPMY/sm30hh/Sxc="
|
||||
crossorigin="anonymous"
|
||||
></script>
|
||||
|
||||
<script
|
||||
src="https://cdn.jsdelivr.net/npm/chartjs-plugin-datalabels@2.2.0/dist/chartjs-plugin-datalabels.min.js"
|
||||
integrity="sha256-IMCPPZxtLvdt9tam8RJ8ABMzn+Mq3SQiInbDmMYwjDg="
|
||||
crossorigin="anonymous"
|
||||
></script>
|
||||
|
||||
<script src="./index.js"></script>
|
||||
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdn.jsdelivr.net/npm/normalize.min.css@8.0.1/normalize.min.css"
|
||||
integrity="sha256-oeib74n7OcB5VoyaI+aGxJKkNEdyxYjd2m3fi/3gKls="
|
||||
crossorigin="anonymous"
|
||||
/>
|
||||
|
||||
<link rel="stylesheet" href="./index.css" />
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="selections">
|
||||
<!-- Backend Selection -->
|
||||
<div class="select-box">
|
||||
1.
|
||||
<label for="backend">Backend:</label>
|
||||
<select id="backend">
|
||||
<option value="ndarray" selected>CPU - Ndarray</option>
|
||||
<option value="candle">CPU - Candle</option>
|
||||
<option value="webgpu">GPU - WebGPU</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row-container">
|
||||
<!-- Image Selection -->
|
||||
<div class="select-box">
|
||||
2.
|
||||
<select id="imageDropdown">
|
||||
<option value="" selected>Select Image</option>
|
||||
<option value="samples/bridge.jpg">Bridge</option>
|
||||
<option value="samples/cat.jpg">Cat</option>
|
||||
<option value="samples/coyote.jpg">Coyote</option>
|
||||
<option value="samples/flamingo.jpg">Flamingo</option>
|
||||
<option value="samples/pelican.jpg">Pelican</option>
|
||||
<option value="samples/table-lamp.jpg">Table Lamp</option>
|
||||
<option value="samples/torch.jpg">Torch</option>
|
||||
</select>
|
||||
or
|
||||
<input type="file" id="fileInput" accept="image/*" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Time Taken -->
|
||||
<div id="time"> </div>
|
||||
|
||||
<!-- Container for the three boxes -->
|
||||
<div class="row-container">
|
||||
<!-- Canvas to Display Image -->
|
||||
<div class="canvas-box">
|
||||
<canvas id="imageCanvas" width="224" height="224"></canvas>
|
||||
</div>
|
||||
|
||||
<!-- Chart -->
|
||||
<div class="chart-box">
|
||||
<canvas id="chart" width="500" height="224"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Clear Button -->
|
||||
<div class="actions">
|
||||
<button id="clearButton">Clear</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- JavaScript Logic -->
|
||||
<script type="module">
|
||||
// TODO - Move this to a separate file (index.js)
|
||||
|
||||
// DOM Elements
|
||||
const imgDropdown = $("imageDropdown");
|
||||
const backendDropdown = $("backend");
|
||||
const fileInput = $("fileInput");
|
||||
const canvas = $("imageCanvas");
|
||||
const ctx = canvas.getContext("2d", { willReadFrequently: true });
|
||||
const clearButton = $("clearButton");
|
||||
const time = $("time");
|
||||
|
||||
const chart = chartConfigBuilder($("chart"));
|
||||
|
||||
// Event Handlers
|
||||
imgDropdown.addEventListener("change", handleImageDropdownChange);
|
||||
backendDropdown.addEventListener("change", handleBackendDropdownChange);
|
||||
fileInput.addEventListener("change", handleFileInputChange);
|
||||
clearButton.addEventListener("click", resetCanvasAndInputs);
|
||||
|
||||
// Module level variables
|
||||
let imageClassifier;
|
||||
|
||||
// Loads the appropriate wasm bundle (SIMD or scalar) and constructs the
// classifier. Stores the instance in the module-level `imageClassifier`.
async function initWasm() {
  let simdSupported = await wasmFeatureDetect.simd();

  if (isSafari()) {
    // TODO enable simd for Safari once it works
    // For some reason NDarray backend is not working on Safari with SIMD enabled
    // Got the following error:
    // recursive use of an object detected which would lead to unsafe aliasing in rust
    console.warn("Safari detected. Disabling wasm simd for now ...");
    simdSupported = false;
  }

  if (simdSupported) {
    console.debug("SIMD is supported");
  } else {
    console.debug("SIMD is not supported");
  }

  const modulePath = simdSupported
    ? "./pkg/simd/image_classification_web.js"
    : "./pkg/no_simd/image_classification_web.js";

  const { default: wasm, ImageClassifier } = await import(modulePath);

  // Bug fix: await wasm initialization before constructing the classifier.
  // The original used `wasm().then(...)`, so initWasm() resolved before
  // `imageClassifier` was assigned, racing with early user interaction.
  await wasm();
  imageClassifier = new ImageClassifier();
}
|
||||
|
||||
initWasm();
|
||||
|
||||
// Check if WebGPU is supported
|
||||
if (!navigator.gpu) {
|
||||
backendDropdown.options[2].disabled = true;
|
||||
alert("WebGPU is not supported on this device.\n\nDisabling WebGPU backend ...");
|
||||
}
|
||||
|
||||
// Function Definitions
|
||||
// Loads the selected sample image (if any) and clears the file picker.
async function handleImageDropdownChange() {
  // A non-empty value means one of the bundled sample images was chosen.
  const selection = this.value;
  if (selection) {
    await loadImage(selection);
  }

  // Picking from the dropdown supersedes any previously chosen file.
  fileInput.value = "";
}
|
||||
|
||||
// Switches the classifier to the backend matching the dropdown selection,
// then resets the UI so stale results are not shown on the new backend.
async function handleBackendDropdownChange() {
  switch (this.value) {
    case "ndarray":
      await imageClassifier.set_backend_ndarray();
      break;
    case "candle":
      await imageClassifier.set_backend_candle();
      break;
    case "webgpu":
      await imageClassifier.set_backend_wgpu();
      break;
  }

  resetCanvasAndInputs();
}
|
||||
|
||||
// Reads a user-supplied image file and loads it into the canvas.
function handleFileInputChange() {
  const files = this.files;
  if (!files || files.length === 0) {
    return;
  }

  // Read the chosen file as a data URL, then draw/classify it on load.
  const reader = new FileReader();
  reader.onload = (event) => loadImage(event.target.result);
  reader.readAsDataURL(files[0]);

  // A file upload supersedes any sample-image selection.
  imgDropdown.selectedIndex = 0;
}
|
||||
|
||||
// Returns the page to its initial state: blank canvas, cleared inputs,
// zeroed chart, empty timing readout.
function resetCanvasAndInputs() {
  // Wipe the preview image.
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  // Reset both image inputs.
  imgDropdown.selectedIndex = 0;
  fileInput.value = "";

  // Zero out the probability chart (5 bars).
  chart.data.labels = Array(5).fill("");
  chart.data.datasets[0].data = Array(5).fill(0.0);
  chart.update();

  // Blank the timing readout.
  time.innerHTML = " ";
  console.log("Cleared canvas");
}
|
||||
|
||||
// Decodes an image from a URL or data URL, draws it on the canvas, and
// kicks off inference.
async function loadImage(src) {
  const img = new Image();

  // Wait for the browser to finish decoding before drawing.
  const decoded = new Promise((resolve) => {
    img.onload = resolve;
  });
  img.src = src;
  await decoded;

  clearAndDrawCanvas(img);

  runInference();
}
|
||||
|
||||
// Runs the classifier on the current canvas contents and updates the chart
// and timing readout with the results.
async function runInference() {
  const rgbValues = extractRGBValuesFromCanvas(canvas, ctx);

  // Time only the model call itself.
  const begin = performance.now();
  const predictions = await imageClassifier.inference(rgbValues);
  const elapsed = performance.now() - begin;

  // Refresh the probability chart with the new predictions.
  const { labels, probabilities } = extractLabelsAndProbabilities(predictions);
  chart.data.labels = labels;
  chart.data.datasets[0].data = probabilities;
  chart.update();

  time.innerHTML = `Inference Time: <span> ${toFixed(elapsed)} </span> ms.`;
}
|
||||
|
||||
// Erases any previous image, then draws the new one scaled to the 224x224
// canvas (the size the model input is taken from).
function clearAndDrawCanvas(img) {
  ctx.clearRect(0, 0, canvas.width, canvas.height);
  ctx.drawImage(img, 0, 0, 224, 224);
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
*
|
||||
* This demo is part of Burn project: https://github.com/burn-rs/burn
|
||||
*
|
||||
* Released under a dual license:
|
||||
* https://github.com/burn-rs/burn/blob/main/LICENSE-MIT
|
||||
* https://github.com/burn-rs/burn/blob/main/LICENSE-APACHE
|
||||
*
|
||||
*/
|
||||
|
||||
|
||||
/**
 * Shorthand for document.getElementById.
 * @param {string} id - Element id.
 * @returns {HTMLElement|null} - The matching element, or null if absent.
 */
function $(id) {
  return document.getElementById(id);
}
|
||||
|
||||
/**
 * Truncates a number to a given number of decimal positions (no rounding).
 * @param {number} num - Number to truncate.
 * @param {number} [fixed] - Decimal positions; omitted or 0 keeps only the
 *   integer part.
 * @returns {string} - The truncated number as a string.
 * src: https://stackoverflow.com/a/11818658
 */
function toFixed(num, fixed) {
  // Bug fix: the original pattern string used '\.', which collapses to '.'
  // (match any character) in the compiled RegExp; '\\.' yields a literal
  // decimal point. It also produced an invalid '{0,-1}' quantifier when
  // `fixed` was omitted — that case now uses an integer-only pattern.
  const digits = fixed > 0 ? fixed : 0;
  const pattern =
    digits > 0 ? '^-?\\d+(?:\\.\\d{0,' + digits + '})?' : '^-?\\d+';
  return num.toString().match(new RegExp(pattern))[0];
}
|
||||
|
||||
/**
 * Builds the horizontal bar chart used to display the top predictions.
 * @param {object} chartEl - Chart canvas element.
 * @returns {object} - The constructed Chart instance.
 *
 * NOTE: Assumes chart.js and the datalabels plugin are loaded into the global.
 */
function chartConfigBuilder(chartEl) {
  Chart.register(ChartDataLabels);

  // Five empty bars until the first inference fills them in.
  const data = {
    labels: ["", "", "", "", ""],
    datasets: [
      {
        data: [0.0, 0.0, 0.0, 0.0, 0.0],
        borderWidth: 0,
        fill: true,
        backgroundColor: "#247ABF",
        axis: "y",
      },
    ],
  };

  const options = {
    responsive: false,
    maintainAspectRatio: false,
    animation: true,
    plugins: {
      legend: {
        display: false,
      },
      tooltip: {
        enabled: true,
      },
      // Print each probability inside its bar, truncated to 2 decimals.
      datalabels: {
        color: "white",
        formatter: (value) => toFixed(value, 2),
      },
    },
    // Horizontal bars: categories on the y axis, probability [0, 1] on x.
    indexAxis: "y",
    scales: {
      y: {},
      x: {
        suggestedMin: 0.0,
        suggestedMax: 1.0,
        beginAtZero: true,
      },
    },
  };

  return new Chart(chartEl, {
    plugins: [ChartDataLabels],
    type: "bar",
    data,
    options,
  });
}
|
||||
|
||||
/** Helper function that extracts labels and probabilities from the data.
 * Entries missing either a `label` or a `probability` own-property are
 * skipped; the two output arrays stay index-aligned.
 * @param {object} data - Data object.
 * @returns {object} - Object with labels and probabilities.
 */
function extractLabelsAndProbabilities(data) {
  const pairs = [];

  for (const item of data) {
    if (item.hasOwnProperty('label') && item.hasOwnProperty('probability')) {
      pairs.push([item.label, item.probability]);
    }
  }

  return {
    labels: pairs.map(([label]) => label),
    probabilities: pairs.map(([, probability]) => probability),
  };
}
|
||||
|
||||
/**
 * Helper function that extracts RGB values from a canvas.
 *
 * Output is channel-first (CHW): all red values, then all green, then all
 * blue, each normalized to [0, 1]. Alpha is dropped.
 * @param {object} canvas - Canvas element.
 * @param {object} ctx - Canvas context.
 * @returns {Float32Array} - Flattened array of RGB values.
 */
function extractRGBValuesFromCanvas(canvas, ctx) {
  const { width, height } = canvas;

  // Interleaved RGBA bytes, 4 per pixel, row-major.
  const rgba = ctx.getImageData(0, 0, width, height).data;

  const planeSize = height * width;
  const chw = new Float32Array(3 * planeSize);

  for (let pixel = 0; pixel < planeSize; pixel++) {
    const offset = pixel * 4;
    chw[pixel] = rgba[offset] / 255.0; // Red plane
    chw[planeSize + pixel] = rgba[offset + 1] / 255.0; // Green plane
    chw[2 * planeSize + pixel] = rgba[offset + 2] / 255.0; // Blue plane
  }

  return chw;
}
|
||||
|
||||
/** Detect if browser is safari
 * @returns {boolean} - True if browser is safari.
 */
function isSafari() {
    // Safari's user agent contains "safari" but, unlike Chrome and Android
    // browsers, never "chrome" or "android".
    // https://stackoverflow.com/questions/7944460/detect-safari-browser
    const safariPattern = /^((?!chrome|android).)*safari/i;
    return safariPattern.test(navigator.userAgent);
}
|
|
@ -0,0 +1,9 @@
|
|||
# Opening index.html file directly by a browser does not work because of
# the security restrictions by the browser.

# The server script needs python3; fail loudly (non-zero status, message on
# stderr) so callers and CI can detect the problem.
if ! command -v python3 &>/dev/null; then
    echo "python3 could not be found. Running server requires python3." >&2
    exit 1
fi

python3 https_server.py
|
After Width: | Height: | Size: 324 KiB |
After Width: | Height: | Size: 105 KiB |
After Width: | Height: | Size: 154 KiB |
After Width: | Height: | Size: 795 KiB |
After Width: | Height: | Size: 4.1 MiB |
After Width: | Height: | Size: 6.7 MiB |
After Width: | Height: | Size: 187 KiB |
|
@ -0,0 +1,6 @@
|
|||
// no_std for non-test builds (required for the wasm target); std stays
// available for native test builds.
#![cfg_attr(not(test), no_std)]

pub mod model;
pub mod web;

// no_std code still needs heap types (String, Vec) from the alloc crate.
extern crate alloc;
|
|
@ -0,0 +1,2 @@
|
|||
// Generated labels from labels.txt
// Pulls in the build-script output (written to OUT_DIR) that defines the
// `LABELS` class-name table used by the classifier.
include!(concat!(env!("OUT_DIR"), "/model/label.rs"));
|
|
@ -0,0 +1,3 @@
|
|||
// Model components: generated class labels, the input normalizer, and the
// generated SqueezeNet network.
pub mod label;
pub mod normalizer;
pub mod squeezenet1;
|
|
@ -0,0 +1,38 @@
|
|||
use burn::tensor::{backend::Backend, Tensor};

// Values are taken from the [ONNX SqueezeNet]
// (https://github.com/onnx/models/tree/main/vision/classification/squeezenet#preprocessing)
// Per-channel (RGB) statistics of the imagenet training set.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalizer for the imagenet dataset.
pub struct Normalizer<B: Backend> {
    // Shape [1, 3, 1, 1] so it broadcasts over an NCHW image batch.
    pub mean: Tensor<B, 4>,
    // Shape [1, 3, 1, 1], broadcast like `mean`.
    pub std: Tensor<B, 4>,
}
|
||||
|
||||
impl<B: Backend> Normalizer<B> {
    /// Creates a new normalizer with the imagenet mean/std constants.
    pub fn new() -> Self {
        // Reshape to [1, 3, 1, 1] so the statistics broadcast across
        // batch, height and width dimensions.
        let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]);
        let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]);
        Self { mean, std }
    }

    /// Normalizes the input image according to the imagenet dataset.
    ///
    /// The input image should be in the range [0, 1].
    /// The output is standardized per channel (approximately zero mean and
    /// unit variance over the imagenet dataset); note individual values can
    /// fall outside [-1, 1] (e.g. (1.0 - 0.406) / 0.225 ≈ 2.6).
    ///
    /// The normalization is done according to the following formula:
    /// `input = (input - mean) / std`
    pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        (input - self.mean.clone()) / self.std.clone()
    }
}
|
||||
|
||||
impl<B: Backend> Default for Normalizer<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// Generated model from squeezenet1.onnx
// The build script converts the ONNX file into Rust source in OUT_DIR; it is
// wrapped in a private module and re-exported so the generated item names
// appear directly under `model::squeezenet1`.
mod internal_model {
    include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs"));
}

pub use internal_model::*;
|
|
@ -0,0 +1,191 @@
|
|||
#![allow(clippy::new_without_default)]
|
||||
|
||||
use alloc::{
|
||||
string::{String, ToString},
|
||||
vec::Vec,
|
||||
};
|
||||
use core::convert::Into;
|
||||
|
||||
use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet1::Model as SqueezenetModel};
|
||||
|
||||
use burn::{
|
||||
backend::{
|
||||
wgpu::{compute::init_async, AutoGraphicsApi, WgpuBackend, WgpuDevice},
|
||||
NdArrayBackend,
|
||||
},
|
||||
tensor::{activation::softmax, backend::Backend, Tensor},
|
||||
};
|
||||
use burn_candle::CandleBackend;
|
||||
|
||||
use serde::Serialize;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_timer::Instant;
|
||||
|
||||
// Variant sizes differ per backend; boxing is not worth it for a single
// long-lived value, hence the allow.
#[allow(clippy::large_enum_variant)]
/// The model is loaded to a specific backend
pub enum ModelType {
    /// The model is loaded to the Candle backend
    WithCandleBackend(Model<CandleBackend<f32, i64>>),

    /// The model is loaded to the NdArray backend
    WithNdarrayBackend(Model<NdArrayBackend<f32>>),

    /// The model is loaded to the Wgpu backend
    WithWgpuBackend(Model<WgpuBackend<AutoGraphicsApi, f32, i32>>),
}
|
||||
|
||||
/// The image is 224x224 pixels with 3 channels (RGB)
const HEIGHT: usize = 224;
const WIDTH: usize = 224;
// Red, green, blue.
const CHANNELS: usize = 3;
|
||||
|
||||
/// The image classifier
#[wasm_bindgen]
pub struct ImageClassifier {
    // Currently active model + backend; swapped by the `set_backend_*` methods.
    model: ModelType,
}
|
||||
|
||||
#[wasm_bindgen]
impl ImageClassifier {
    /// Constructor called by JavaScript with the `new` keyword.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        // Initialize the logger so that the logs are printed to the console
        wasm_logger::init(wasm_logger::Config::default());

        log::info!("Initializing the image classifier");

        // NdArray is the default backend; JS can switch it later via the
        // `set_backend_*` methods.
        Self {
            model: ModelType::WithNdarrayBackend(Model::new()),
        }
    }

    /// Runs inference on the image.
    ///
    /// `input` is a flattened channel-first RGB image of
    /// `CHANNELS * HEIGHT * WIDTH` floats in [0, 1]. Returns the top-5
    /// classes serialized as a JS value (see `top_5_classes`).
    pub async fn inference(&self, input: &[f32]) -> Result<JsValue, JsValue> {
        log::info!("Running inference on the image");

        let start = Instant::now();

        // Dispatch to whichever backend currently holds the model.
        let result = match self.model {
            ModelType::WithCandleBackend(ref model) => model.forward(input).await,
            ModelType::WithNdarrayBackend(ref model) => model.forward(input).await,
            ModelType::WithWgpuBackend(ref model) => model.forward(input).await,
        };

        let duration = start.elapsed();

        log::debug!("Inference is completed in {:?}", duration);

        top_5_classes(result)
    }

    /// Sets the backend to Candle
    pub async fn set_backend_candle(&mut self) -> Result<(), JsValue> {
        log::info!("Loading the model to the Candle backend");
        let start = Instant::now();
        self.model = ModelType::WithCandleBackend(Model::new());
        let duration = start.elapsed();
        log::debug!("Model is loaded to the Candle backend in {:?}", duration);
        Ok(())
    }

    /// Sets the backend to NdArray
    pub async fn set_backend_ndarray(&mut self) -> Result<(), JsValue> {
        log::info!("Loading the model to the NdArray backend");
        let start = Instant::now();
        self.model = ModelType::WithNdarrayBackend(Model::new());
        let duration = start.elapsed();
        log::debug!("Model is loaded to the NdArray backend in {:?}", duration);
        Ok(())
    }

    /// Sets the backend to Wgpu
    pub async fn set_backend_wgpu(&mut self) -> Result<(), JsValue> {
        log::info!("Loading the model to the Wgpu backend");
        let start = Instant::now();
        // The wgpu device must be initialized asynchronously before the
        // model can be created on it.
        init_async::<AutoGraphicsApi>(&WgpuDevice::default()).await;
        self.model = ModelType::WithWgpuBackend(Model::new());
        let duration = start.elapsed();
        log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);

        log::debug!("Warming up the model");
        // Run one dummy inference up front so first-call setup costs (e.g.
        // GPU pipeline preparation) are not paid on the first real request;
        // its result is deliberately discarded.
        let start = Instant::now();
        let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await;
        let duration = start.elapsed();
        log::debug!("Warming up is completed in {:?}", duration);
        Ok(())
    }
}
|
||||
|
||||
/// The image classifier model
pub struct Model<B: Backend> {
    // Generated SqueezeNet network (weights loaded via `from_embedded`).
    model: SqueezenetModel<B>,
    // Imagenet mean/std normalizer applied before inference.
    normalizer: Normalizer<B>,
}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
/// Constructor
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
model: SqueezenetModel::from_embedded(),
|
||||
normalizer: Normalizer::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalizes input and runs inference on the image
|
||||
pub async fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
// Reshape from the 1D array to 3d tensor [ width, height, channels]
|
||||
let input: Tensor<B, 4> = Tensor::from_floats(input).reshape([1, CHANNELS, HEIGHT, WIDTH]);
|
||||
|
||||
// Normalize input: make between [-1,1] and make the mean=0 and std=1
|
||||
let input = self.normalizer.normalize(input);
|
||||
|
||||
// Run the tensor input through the model
|
||||
let output = self.model.forward(input);
|
||||
|
||||
// Convert the model output into probability distribution using softmax formula
|
||||
let probabilies = softmax(output, 1);
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
let result = probabilies.into_data().convert::<f32>().value;
|
||||
|
||||
// Forces the result to be computed
|
||||
#[cfg(target_family = "wasm")]
|
||||
let result = probabilies.into_data().await.convert::<f32>().value;
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// A single classification result entry, serialized back to JavaScript.
#[wasm_bindgen]
#[derive(Serialize)]
pub struct InferenceResult {
    // Class index into `LABELS`.
    index: usize,
    // Softmax probability for this class.
    probability: f32,
    // Human-readable class name from `LABELS`.
    label: String,
}
|
||||
|
||||
/// Returns the top 5 classes and convert them into a JsValue
|
||||
fn top_5_classes(probabilies: Vec<f32>) -> Result<JsValue, JsValue> {
|
||||
// Convert the probabilities into a vector of (index, probability)
|
||||
let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect();
|
||||
|
||||
// Sort the probabilities in descending order
|
||||
probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||
|
||||
// Take the top 5 probabilities
|
||||
probabilies.truncate(5);
|
||||
|
||||
// Convert the probabilities into InferenceResult
|
||||
let result: Vec<InferenceResult> = probabilies
|
||||
.into_iter()
|
||||
.map(|(index, probability)| InferenceResult {
|
||||
index,
|
||||
probability: *probability,
|
||||
label: LABELS[index].to_string(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Convert the InferenceResult into a JsValue
|
||||
Ok(serde_wasm_bindgen::to_value(&result)?)
|
||||
}
|