Add a wasm module for the segment anything example. (#797)

This commit is contained in:
Laurent Mazare 2023-09-10 12:29:37 +01:00 committed by GitHub
parent 6c58fc59fd
commit 584171cae1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 189 additions and 3 deletions

View File

@ -8,6 +8,7 @@ members = [
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
]

View File

@ -122,6 +122,11 @@ impl Sam {
})
}
pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
let img = self.preprocess(img)?.unsqueeze(0)?;
self.image_encoder.forward(&img)
}
pub fn forward(
&self,
img: &Tensor,
@ -131,15 +136,32 @@ impl Sam {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
self.forward_for_embeddings(
&img_embeddings,
original_h,
original_w,
point,
multimask_output,
)
}
pub fn forward_for_embeddings(
&self,
img_embeddings: &Tensor,
original_h: usize,
original_w: usize,
point: Option<(f64, f64)>,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
img.device(),
img_embeddings.device(),
)?;
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
}
};
@ -147,7 +169,7 @@ impl Sam {
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,

View File

@ -0,0 +1,29 @@
[package]
name = "candle-wasm-example-sam"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
num-traits = { workspace = true }
# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
image = { workspace = true }
log = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
wasm-bindgen = "0.2.87"

View File

@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@ -0,0 +1,113 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_wasm_example_sam as sam;
use wasm_bindgen::prelude::*;
#[allow(unused)]
struct Embeddings {
original_width: u32,
original_height: u32,
width: u32,
height: u32,
data: Tensor,
}
#[wasm_bindgen]
pub struct Model {
sam: sam::Sam,
embeddings: Option<Embeddings>,
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
let dev = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
let sam = if use_tiny {
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
Ok(Self {
sam,
embeddings: None,
})
}
pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
sam::console_log!("image data: {}", image_data.len());
let image_data = std::io::Cursor::new(image_data);
let image = image::io::Reader::new(image_data)
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;
let (original_height, original_width) = (image.height(), image.width());
let (height, width) = (original_height, original_width);
let resize_longest = sam::IMAGE_SIZE as u32;
let (height, width) = if height < width {
let h = (resize_longest * height) / width;
(h, resize_longest)
} else {
let w = (resize_longest * width) / height;
(resize_longest, w)
};
let image_t = {
let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
let data = img.to_rgb8().into_raw();
Tensor::from_vec(
data,
(img.height() as usize, img.width() as usize, 3),
&Device::Cpu,
)?
.permute((2, 0, 1))?
};
let data = self.sam.embeddings(&image_t)?;
self.embeddings = Some(Embeddings {
original_width,
original_height,
width,
height,
data,
});
Ok(())
}
// x and y have to be between 0 and 1
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
let embeddings = match &self.embeddings {
None => todo!(),
Some(embeddings) => embeddings,
};
let (mask, iou_predictions) = self.sam.forward_for_embeddings(
&embeddings.data,
embeddings.height as usize,
embeddings.width as usize,
Some((x, y)),
false,
)?;
let iou = iou_predictions.to_vec1::<f32>()?[0];
let mask_shape = mask.dims().to_vec();
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
let mask = Mask {
iou,
mask_shape,
mask_data,
};
let json = serde_json::to_string(&mask)?;
Ok(json)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct Mask {
iou: f32,
mask_shape: Vec<usize>,
mask_data: Vec<u8>,
}
fn main() {
console_error_panic_hook::set_once();
}

View File

@ -0,0 +1,19 @@
use candle_transformers::models::segment_anything::sam;
use wasm_bindgen::prelude::*;
pub use sam::{Sam, IMAGE_SIZE};
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}