mirror of https://github.com/tracel-ai/burn.git
244 lines
7.5 KiB
HTML
244 lines
7.5 KiB
HTML
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<meta charset="UTF-8" />
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
|
|
<title>Image Classification</title>
|
|
|
|
<script
|
|
src="https://cdn.jsdelivr.net/npm/wasm-feature-detect@1.5.1/dist/umd/index.min.js"
|
|
integrity="sha256-9+AQR2dApXE+f/D998vy0RATN/o4++mqVjAZ3lo432g="
|
|
crossorigin="anonymous"
|
|
></script>
|
|
|
|
<script
|
|
src="https://cdn.jsdelivr.net/npm/chart.js@4.2.1/dist/chart.umd.min.js"
|
|
integrity="sha256-tgiW1vJqfIKxE0F2uVvsXbgUlTyrhPMY/sm30hh/Sxc="
|
|
crossorigin="anonymous"
|
|
></script>
|
|
|
|
<script
|
|
src="https://cdn.jsdelivr.net/npm/chartjs-plugin-datalabels@2.2.0/dist/chartjs-plugin-datalabels.min.js"
|
|
integrity="sha256-IMCPPZxtLvdt9tam8RJ8ABMzn+Mq3SQiInbDmMYwjDg="
|
|
crossorigin="anonymous"
|
|
></script>
|
|
|
|
<script src="./index.js"></script>
|
|
|
|
<link
|
|
rel="stylesheet"
|
|
href="https://cdn.jsdelivr.net/npm/normalize.min.css@8.0.1/normalize.min.css"
|
|
integrity="sha256-oeib74n7OcB5VoyaI+aGxJKkNEdyxYjd2m3fi/3gKls="
|
|
crossorigin="anonymous"
|
|
/>
|
|
|
|
<link rel="stylesheet" href="./index.css" />
|
|
</head>
|
|
<body>
|
|
<div class="container">
|
|
<div class="selections">
|
|
<!-- Backend Selection -->
|
|
<div class="select-box">
|
|
1.
|
|
<label for="backend">Backend:</label>
|
|
<select id="backend">
|
|
<option value="ndarray" selected>CPU - Ndarray</option>
|
|
<option value="candle">CPU - Candle</option>
|
|
<option value="webgpu">GPU - WebGPU</option>
|
|
</select>
|
|
</div>
|
|
</div>
|
|
|
|
<div class="row-container">
|
|
<!-- Image Selection -->
|
|
<div class="select-box">
  2.
  <!-- This select has no visible <label>, so aria-label supplies its accessible name. -->
  <select id="imageDropdown" aria-label="Sample image">
    <option value="" selected>Select Image</option>
    <option value="samples/bridge.jpg">Bridge</option>
    <option value="samples/cat.jpg">Cat</option>
    <option value="samples/coyote.jpg">Coyote</option>
    <option value="samples/flamingo.jpg">Flamingo</option>
    <option value="samples/pelican.jpg">Pelican</option>
    <option value="samples/table-lamp.jpg">Table Lamp</option>
    <option value="samples/torch.jpg">Torch</option>
  </select>
  or
  <!-- Same for the file picker: give it an accessible name without changing the layout. -->
  <input type="file" id="fileInput" accept="image/*" aria-label="Upload an image" />
</div>
|
|
</div>
|
|
|
|
<!-- Time Taken -->
|
|
<div id="time">&nbsp;</div>
|
|
|
|
<!-- Container for the three boxes -->
|
|
<div class="row-container">
|
|
<!-- Canvas to Display Image -->
|
|
<div class="canvas-box">
|
|
<canvas id="imageCanvas" width="224" height="224"></canvas>
|
|
</div>
|
|
|
|
<!-- Chart -->
|
|
<div class="chart-box">
|
|
<canvas id="chart" width="500" height="224"></canvas>
|
|
</div>
|
|
</div>
|
|
|
|
<!-- Clear Button -->
|
|
<div class="actions">
|
|
<button id="clearButton">Clear</button>
|
|
</div>
|
|
</div>
|
|
|
|
<!-- JavaScript Logic -->
|
|
<script type="module">
|
|
// TODO - Move this to a separate file (index.js)
|
|
|
|
// Element handles used throughout this module.
// `$` is not defined in this script; presumably an id-lookup helper from index.js.
const imgDropdown = $("imageDropdown");
const backendDropdown = $("backend");
const fileInput = $("fileInput");
const canvas = $("imageCanvas");
const ctx = canvas.getContext("2d", { willReadFrequently: true });
const clearButton = $("clearButton");
const time = $("time");

// Chart object built around the #chart canvas by chartConfigBuilder (defined outside this script).
const chart = chartConfigBuilder($("chart"));

// Listener table: [element, event name, handler]. Registration order matches the row order.
const listenerTable = [
  [imgDropdown, "change", handleImageDropdownChange],
  [backendDropdown, "change", handleBackendDropdownChange],
  [fileInput, "change", handleFileInputChange],
  [clearButton, "click", resetCanvasAndInputs],
];
for (const [element, eventName, handler] of listenerTable) {
  element.addEventListener(eventName, handler);
}

// Module-level classifier instance; stays undefined until initWasm() assigns it.
let imageClassifier;
|
|
|
|
/**
 * Picks the SIMD or non-SIMD WebAssembly build, loads it, and stores a new
 * ImageClassifier in the module-level `imageClassifier` variable.
 *
 * The returned promise settles only after the classifier has been created, and
 * rejects if feature detection, the dynamic import, or WASM initialization fails.
 *
 * @returns {Promise<void>}
 */
async function initWasm() {
  let simdSupported = await wasmFeatureDetect.simd();

  if (isSafari()) {
    // TODO enable simd for Safari once it works
    // For some reason NDarray backend is not working on Safari with SIMD enabled
    // Got the following error:
    // recursive use of an object detected which would lead to unsafe aliasing in rust
    console.warn("Safari detected. Disabling wasm simd for now ...");
    simdSupported = false;
  }

  console.debug(simdSupported ? "SIMD is supported" : "SIMD is not supported");

  const modulePath = simdSupported
    ? "./pkg/simd/image_classification_web.js"
    : "./pkg/no_simd/image_classification_web.js";

  const { default: wasm, ImageClassifier } = await import(modulePath);

  // Await the module's default export (presumably the wasm-bindgen init function)
  // instead of chaining a detached `.then()`: initialization errors now reject
  // this function's promise, and callers can rely on the classifier existing
  // once initWasm() has resolved.
  await wasm();

  // Initialize the classifier and save to module level variable
  imageClassifier = new ImageClassifier();
}
|
|
|
|
// Kick off WASM loading in the background. The promise is intentionally not
// awaited so the rest of the page setup continues, but failures are reported
// instead of surfacing as an unhandled rejection.
initWasm().catch((error) => {
  console.error("Failed to initialize the WASM image classifier:", error);
});

// Check if WebGPU is supported
if (!navigator.gpu) {
  // Find the option by its value rather than by position, so reordering the
  // backend list cannot cause a different backend to be disabled.
  const webgpuOption = backendDropdown.querySelector('option[value="webgpu"]');
  if (webgpuOption) {
    webgpuOption.disabled = true;
  }
  alert("WebGPU is not supported on this device.\n\nDisabling WebGPU backend ...");
}
|
|
|
|
// Function Definitions
|
|
/**
 * "change" listener for the sample-image dropdown.
 * Expects `this` to be the dropdown element. A non-empty value is passed to
 * loadImage(); afterwards the file input is cleared in every case.
 *
 * @returns {Promise<void>}
 */
async function handleImageDropdownChange() {
  const selectedSample = this.value;

  // The placeholder entry has an empty value and loads nothing.
  if (selectedSample) {
    await loadImage(selectedSample);
  }

  // A sample was chosen (or deselected), so drop any previously picked file.
  fileInput.value = "";
}
|
|
|
|
/**
 * "change" listener for the backend dropdown.
 * Expects `this` to be the dropdown element. Calls the classifier's setter
 * matching the selected value, waits for it, then resets the canvas, inputs,
 * chart and timing display. Unrecognized values skip straight to the reset.
 *
 * @returns {Promise<void>}
 */
async function handleBackendDropdownChange() {
  switch (this.value) {
    case "ndarray":
      await imageClassifier.set_backend_ndarray();
      break;
    case "candle":
      await imageClassifier.set_backend_candle();
      break;
    case "webgpu":
      await imageClassifier.set_backend_wgpu();
      break;
    default:
      break;
  }

  resetCanvasAndInputs();
}
|
|
|
|
/**
 * "change" listener for the file input.
 * Expects `this` to be the file input element. Reads the first selected file
 * as a data URL, passes that URL to loadImage() once reading finishes, and
 * moves the sample dropdown back to its first entry. Does nothing when no
 * file is selected.
 */
function handleFileInputChange() {
  const pickedFile = this.files && this.files[0];
  if (!pickedFile) {
    return;
  }

  const fileReader = new FileReader();
  fileReader.onload = (loadEvent) => loadImage(loadEvent.target.result);
  fileReader.readAsDataURL(pickedFile);

  // An uploaded file replaces any sample selection.
  imgDropdown.selectedIndex = 0;
}
|
|
|
|
/**
 * Returns the page to its idle state: clears the image canvas, resets the
 * sample dropdown and file input, empties the chart, and blanks the timing line.
 */
function resetCanvasAndInputs() {
  // Clear canvas
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  // Reset dropdown to its first entry
  imgDropdown.selectedIndex = 0;

  // Reset file input
  fileInput.value = "";

  // Clear chart: five empty labels with zeroed values
  chart.data.labels = ["", "", "", "", ""];
  chart.data.datasets[0].data = [0.0, 0.0, 0.0, 0.0, 0.0];
  chart.update();

  // Clear time. An explicit non-breaking space keeps the element from becoming
  // empty; writing it as an escape through textContent keeps the character
  // visible in source and avoids parsing HTML for a plain placeholder.
  time.textContent = "\u00A0";

  console.log("Cleared canvas");
}
|
|
|
|
/**
 * Loads an image from `src`, draws it onto the canvas, and runs inference on it.
 *
 * If the image cannot be loaded, or inference throws, the problem is logged
 * and the function returns normally, so callers can continue their own
 * clean-up after awaiting it.
 *
 * @param {string} src - Image URL (a relative sample path or a data URL).
 * @returns {Promise<void>} Settles after inference has finished or been skipped.
 */
async function loadImage(src) {
  const img = new Image();

  const loaded = await new Promise((resolve) => {
    img.onload = () => resolve(true);
    // Without an error handler a broken or unsupported image would leave this
    // promise pending forever.
    img.onerror = () => resolve(false);
    // Assign the source only after both handlers are attached.
    img.src = src;
  });

  if (!loaded) {
    console.error("Failed to load image; skipping inference.");
    return;
  }

  clearAndDrawCanvas(img);

  // Await inference so its failures are reported here rather than lost in a
  // detached promise.
  try {
    await runInference();
  } catch (error) {
    console.error("Inference failed:", error);
  }
}
|
|
|
|
/**
 * Runs the classifier on the current canvas contents, then updates the chart
 * and the elapsed-time display.
 *
 * If the classifier has not been created yet (WASM still loading, or loading
 * failed), a status message is shown and no inference is attempted.
 *
 * @returns {Promise<void>}
 */
async function runInference() {
  // `imageClassifier` is assigned asynchronously during start-up, so an image
  // can be selected before it exists.
  if (!imageClassifier) {
    time.textContent = "Classifier is not ready yet. Please try again in a moment.";
    return;
  }

  const data = extractRGBValuesFromCanvas(canvas, ctx);

  // Run inference, timing only the classifier call itself
  const startTime = performance.now();
  const output = await imageClassifier.inference(data);
  const timeTaken = performance.now() - startTime;

  // Update chart
  const { labels, probabilities } = extractLabelsAndProbabilities(output);
  chart.data.labels = labels;
  chart.data.datasets[0].data = probabilities;
  chart.update();

  // Only a formatted number is interpolated into this markup.
  time.innerHTML = `Inference Time: <span> ${toFixed(timeTaken)} </span> ms.`;
}
|
|
|
|
/**
 * Wipes the canvas and draws `img` at the origin, scaled to 224x224.
 *
 * @param {CanvasImageSource} img - Image to draw.
 */
function clearAndDrawCanvas(img) {
  const { width, height } = canvas;
  ctx.clearRect(0, 0, width, height);

  // Fixed target size, independent of the image's own dimensions.
  const targetSide = 224;
  ctx.drawImage(img, 0, 0, targetSide, targetSide);
}
|
|
</script>
|
|
</body>
|
|
</html>
|