diff --git a/examples/mnist-inference-web/Cargo.toml b/examples/mnist-inference-web/Cargo.toml index e2ca717da..db28cc2f8 100644 --- a/examples/mnist-inference-web/Cargo.toml +++ b/examples/mnist-inference-web/Cargo.toml @@ -15,6 +15,7 @@ default = [] [dependencies] burn = {path = "../../burn", default-features = false} burn-ndarray = {path = "../../burn-ndarray", default-features = false} +burn-wgpu = {path = "../../burn-wgpu", default-features = false} serde = {workspace = true} wasm-bindgen = "0.2.87" diff --git a/examples/mnist-inference-web/build-for-web.sh b/examples/mnist-inference-web/build-for-web.sh index 7aeec2b8b..90934915c 100755 --- a/examples/mnist-inference-web/build-for-web.sh +++ b/examples/mnist-inference-web/build-for-web.sh @@ -10,7 +10,7 @@ then fi # Set optimization flags -export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3" +export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg=web_sys_unstable_apis" # Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory. mkdir -p pkg diff --git a/examples/mnist-inference-web/src/state.rs b/examples/mnist-inference-web/src/state.rs index 58c19c043..44bfa07fd 100644 --- a/examples/mnist-inference-web/src/state.rs +++ b/examples/mnist-inference-web/src/state.rs @@ -4,8 +4,10 @@ use burn::record::BinBytesRecorder; use burn::record::FullPrecisionSettings; use burn::record::Recorder; use burn_ndarray::NdArrayBackend; +use burn_wgpu::WgpuBackend; -pub type Backend = NdArrayBackend; +pub type Backend = WgpuBackend; +// pub type Backend = NdArrayBackend; static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");