burn/examples/image-classification-web/index.html

244 lines
7.5 KiB
HTML

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Image Classification</title>
<!-- Third-party assets below are pinned to exact versions on jsDelivr and carry Subresource Integrity hashes. -->
<!-- wasm-feature-detect: presumably the source of the wasmFeatureDetect global used by initWasm() in the inline module script. -->
<script
src="https://cdn.jsdelivr.net/npm/wasm-feature-detect@1.5.1/dist/umd/index.min.js"
integrity="sha256-9+AQR2dApXE+f/D998vy0RATN/o4++mqVjAZ3lo432g="
crossorigin="anonymous"
></script>
<!-- Chart.js: presumably consumed by chartConfigBuilder(), which is not defined in this file. -->
<script
src="https://cdn.jsdelivr.net/npm/chart.js@4.2.1/dist/chart.umd.min.js"
integrity="sha256-tgiW1vJqfIKxE0F2uVvsXbgUlTyrhPMY/sm30hh/Sxc="
crossorigin="anonymous"
></script>
<!-- Data-labels plugin for Chart.js. -->
<script
src="https://cdn.jsdelivr.net/npm/chartjs-plugin-datalabels@2.2.0/dist/chartjs-plugin-datalabels.min.js"
integrity="sha256-IMCPPZxtLvdt9tam8RJ8ABMzn+Mq3SQiInbDmMYwjDg="
crossorigin="anonymous"
></script>
<!-- Local helper script. NOTE(review): $, chartConfigBuilder, isSafari, toFixed and the extract* helpers used by the inline module are not defined in this file, so they presumably come from here; confirm. -->
<script src="./index.js"></script>
<!-- normalize.css is loaded first so that the local stylesheet after it can override it. -->
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/normalize.min.css@8.0.1/normalize.min.css"
integrity="sha256-oeib74n7OcB5VoyaI+aGxJKkNEdyxYjd2m3fi/3gKls="
crossorigin="anonymous"
/>
<link rel="stylesheet" href="./index.css" />
</head>
<body>
<div class="container">
<div class="selections">
<!-- Backend Selection -->
<div class="select-box">
1.
<label for="backend">Backend:</label>
<select id="backend">
<option value="ndarray" selected>CPU - Ndarray</option>
<option value="candle">CPU - Candle</option>
<option value="webgpu">GPU - WebGPU</option>
</select>
</div>
</div>
<div class="row-container">
<!-- Image Selection -->
<div class="select-box">
2.
<select id="imageDropdown">
<option value="" selected>Select Image</option>
<option value="samples/bridge.jpg">Bridge</option>
<option value="samples/cat.jpg">Cat</option>
<option value="samples/coyote.jpg">Coyote</option>
<option value="samples/flamingo.jpg">Flamingo</option>
<option value="samples/pelican.jpg">Pelican</option>
<option value="samples/table-lamp.jpg">Table Lamp</option>
<option value="samples/torch.jpg">Torch</option>
</select>
or
<input type="file" id="fileInput" accept="image/*" />
</div>
</div>
<!-- Time Taken -->
<div id="time">&nbsp;</div>
<!-- Container for the image canvas and chart boxes -->
<div class="row-container">
<!-- Canvas to Display Image -->
<div class="canvas-box">
<canvas id="imageCanvas" width="224" height="224"></canvas>
</div>
<!-- Chart -->
<div class="chart-box">
<canvas id="chart" width="500" height="224"></canvas>
</div>
</div>
<!-- Clear Button -->
<div class="actions">
<button id="clearButton">Clear</button>
</div>
</div>
<!-- JavaScript Logic -->
<script type="module">
// TODO - Move this to a separate file (index.js)
// Cached references to the page controls.
// NOTE(review): `$` and `chartConfigBuilder` are not defined in this file; presumably index.js provides them.
const imgDropdown = $("imageDropdown");
const backendDropdown = $("backend");
const fileInput = $("fileInput");
const canvas = $("imageCanvas");
// The 2D context is created with the willReadFrequently hint enabled.
const ctx = canvas.getContext("2d", { willReadFrequently: true });
const clearButton = $("clearButton");
const time = $("time");
const chart = chartConfigBuilder($("chart"));

// Wire each control to its handler: [element, event name, handler].
for (const [element, eventName, handler] of [
  [imgDropdown, "change", handleImageDropdownChange],
  [backendDropdown, "change", handleBackendDropdownChange],
  [fileInput, "change", handleFileInputChange],
  [clearButton, "click", resetCanvasAndInputs],
]) {
  element.addEventListener(eventName, handler);
}

// Classifier instance shared by the handlers; stays undefined until initWasm() assigns it.
let imageClassifier;
async function initWasm() {
let simdSupported = await wasmFeatureDetect.simd();
if (isSafari()) {
// TODO enable simd for Safari once it works
// For some reason NDarray backend is not working on Safari with SIMD enabled
// Got the following error:
// recursive use of an object detected which would lead to unsafe aliasing in rust
console.warn("Safari detected. Disabling wasm simd for now ...");
simdSupported = false;
}
if (simdSupported) {
console.debug("SIMD is supported");
} else {
console.debug("SIMD is not supported");
}
let modulePath = simdSupported
? "./pkg/simd/image_classification_web.js"
: "./pkg/no_simd/image_classification_web.js";
const { default: wasm, ImageClassifier } = await import(modulePath);
wasm().then(() => {
// Initialize the classifier and save to module level variable
imageClassifier = new ImageClassifier();
});
}
// Start loading the classifier; report failures instead of leaving an unhandled rejection.
initWasm().catch((error) => {
  console.error("Failed to initialize the image classifier:", error);
});
// Check if WebGPU is supported
if (!navigator.gpu) {
  // Look the option up by value so the check keeps working if the <option> order changes.
  const webgpuOption = backendDropdown.querySelector('option[value="webgpu"]');
  if (webgpuOption) {
    webgpuOption.disabled = true;
  }
  alert("WebGPU is not supported on this device.\n\nDisabling WebGPU backend ...");
}
// Function Definitions
/**
 * "change" handler for the sample-image dropdown (`this` is the <select>).
 * Loads the chosen sample when a non-empty value is selected, then clears the
 * file input.
 */
async function handleImageDropdownChange() {
  const { value: samplePath } = this;
  if (samplePath) {
    await loadImage(samplePath);
  }
  // The dropdown and the file picker are alternative sources; empty the picker.
  fileInput.value = "";
}
/**
 * "change" handler for the backend dropdown (`this` is the <select>).
 * Switches the classifier to the selected backend and then resets the canvas,
 * inputs and chart. An unrecognized value switches nothing but still resets.
 *
 * If the classifier has not been created yet (the WASM module is loaded
 * asynchronously), the change is ignored and the dropdown is put back on its
 * default option instead of throwing on an undefined classifier.
 */
async function handleBackendDropdownChange() {
  if (!imageClassifier) {
    console.warn("Image classifier is not ready yet; ignoring backend change.");
    // No backend switch has happened, so show the markup's default option again.
    // NOTE(review): assumes the classifier starts on the default-selected backend.
    for (const option of this.options) {
      option.selected = option.defaultSelected;
    }
    return;
  }

  switch (this.value) {
    case "ndarray":
      await imageClassifier.set_backend_ndarray();
      break;
    case "candle":
      await imageClassifier.set_backend_candle();
      break;
    case "webgpu":
      await imageClassifier.set_backend_wgpu();
      break;
    default:
      break;
  }
  resetCanvasAndInputs();
}
/**
 * "change" handler for the file input (`this` is the <input type="file">).
 * Reads the first chosen file as a data URL, passes the result to loadImage(),
 * and puts the sample dropdown back on its first entry. Does nothing when no
 * file is available.
 */
function handleFileInputChange() {
  const chosenFile = this.files?.[0];
  if (!chosenFile) {
    return;
  }
  const reader = new FileReader();
  reader.onload = ({ target }) => loadImage(target.result);
  reader.readAsDataURL(chosenFile);
  // The file picker and the dropdown are alternative sources; reset the dropdown.
  imgDropdown.selectedIndex = 0;
}
/**
 * Returns the page to its idle state: clears the image canvas, resets the
 * sample dropdown and the file input, empties the chart, and blanks the
 * timing line.
 */
function resetCanvasAndInputs() {
  // Clear canvas and reset inputs
  ctx.clearRect(0, 0, canvas.width, canvas.height);
  // Reset dropdowns
  imgDropdown.selectedIndex = 0;
  // Reset file input
  fileInput.value = "";
  // Clear chart: five blank labels with zeroed values
  chart.data.labels = Array(5).fill("");
  chart.data.datasets[0].data = Array(5).fill(0.0);
  chart.update();
  // Clear time. Use the same non-breaking-space placeholder as the initial
  // markup; a plain space collapses and would leave the element empty.
  time.innerHTML = "&nbsp;";
  console.log("Cleared canvas");
}
/**
 * Loads an image, draws it onto the canvas and runs inference on it.
 *
 * @param {string} src - Image URL or data URL to load.
 * @returns {Promise<void>} Resolves once inference has finished; rejects if
 *   the image cannot be loaded or inference fails.
 */
async function loadImage(src) {
  const img = new Image();
  await new Promise((resolve, reject) => {
    // Attach both handlers before assigning `src` so no event can be missed,
    // and reject on failure so a broken image does not leave this pending forever.
    img.onload = resolve;
    img.onerror = () => reject(new Error("Failed to load image"));
    img.src = src;
  });
  clearAndDrawCanvas(img);
  // Await inference so its errors reach the caller instead of a detached promise.
  await runInference();
}
/**
 * Classifies the image currently on the canvas, then updates the chart and
 * the timing line with the result.
 *
 * If the classifier has not been created yet (the WASM module is loaded
 * asynchronously), nothing is classified and a short notice is shown instead
 * of throwing on an undefined classifier.
 *
 * @returns {Promise<void>}
 */
async function runInference() {
  if (!imageClassifier) {
    console.warn("Image classifier is not ready yet; skipping inference.");
    time.textContent = "Model is still loading, please try again in a moment.";
    return;
  }

  const data = extractRGBValuesFromCanvas(canvas, ctx);

  // Run inference; only the classifier call itself is timed
  const startTime = performance.now();
  const output = await imageClassifier.inference(data);
  const timeTaken = performance.now() - startTime;

  // Update chart
  const { labels, probabilities } = extractLabelsAndProbabilities(output);
  chart.data.labels = labels;
  chart.data.datasets[0].data = probabilities;
  chart.update();

  time.innerHTML = `Inference Time: <span> ${toFixed(timeTaken)} </span> ms.`;
}
/**
 * Wipes the whole canvas, then paints `img` into a 224x224 area anchored at
 * the top-left corner.
 *
 * @param {CanvasImageSource} img - Image to draw.
 */
function clearAndDrawCanvas(img) {
  const { width: canvasWidth, height: canvasHeight } = canvas;
  // Remove whatever was drawn previously
  ctx.clearRect(0, 0, canvasWidth, canvasHeight);
  ctx.drawImage(img, 0, 0, 224, 224);
}
</script>
</body>
</html>