Add a wasm module for the segment anything example. (#797)
This commit is contained in:
parent
6c58fc59fd
commit
584171cae1
|
@ -8,6 +8,7 @@ members = [
|
|||
"candle-pyo3",
|
||||
"candle-transformers",
|
||||
"candle-wasm-examples/llama2-c",
|
||||
"candle-wasm-examples/segment-anything",
|
||||
"candle-wasm-examples/whisper",
|
||||
"candle-wasm-examples/yolo",
|
||||
]
|
||||
|
|
|
@ -122,6 +122,11 @@ impl Sam {
|
|||
})
|
||||
}
|
||||
|
||||
/// Runs only the image-encoder half of the model.
///
/// The image is passed through `preprocess`, given a leading dimension of size one, and fed
/// to the image encoder. The encoder output is returned unchanged so that it can be reused
/// by `forward_for_embeddings`.
pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
    let preprocessed = self.preprocess(img)?;
    let batched = preprocessed.unsqueeze(0)?;
    self.image_encoder.forward(&batched)
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
|
@ -131,15 +136,32 @@ impl Sam {
|
|||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
self.forward_for_embeddings(
|
||||
&img_embeddings,
|
||||
original_h,
|
||||
original_w,
|
||||
point,
|
||||
multimask_output,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn forward_for_embeddings(
|
||||
&self,
|
||||
img_embeddings: &Tensor,
|
||||
original_h: usize,
|
||||
original_w: usize,
|
||||
point: Option<(f64, f64)>,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = match point {
|
||||
None => None,
|
||||
Some((x, y)) => {
|
||||
let points = Tensor::new(
|
||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||
img.device(),
|
||||
img_embeddings.device(),
|
||||
)?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
|
||||
Some((points, labels))
|
||||
}
|
||||
};
|
||||
|
@ -147,7 +169,7 @@ impl Sam {
|
|||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(points, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "candle-wasm-example-sam"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.2.1" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
# App crates.
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
image = { workspace = true }
|
||||
log = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
wasm-bindgen = "0.2.87"
|
|
@ -0,0 +1,2 @@
|
|||
# Compile the crate for the wasm32-unknown-unknown target in release mode.
cargo build --target wasm32-unknown-unknown --release
|
||||
# Generate the JavaScript bindings for the `m` binary into ./build, targeting the web.
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
|
|
@ -0,0 +1,113 @@
|
|||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_wasm_example_sam as sam;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Image embeddings cached by `Model::set_image_embeddings`, together with the image sizes
/// they were computed from.
// `allow(unused)`: not every field is read back by the code in this file (the original
// dimensions are only stored).
#[allow(unused)]
|
||||
struct Embeddings {
|
||||
// Width in pixels of the decoded input image, before resizing.
original_width: u32,
|
||||
// Height in pixels of the decoded input image, before resizing.
original_height: u32,
|
||||
// Width in pixels of the resized image that was fed to the model.
width: u32,
|
||||
// Height in pixels of the resized image that was fed to the model.
height: u32,
|
||||
// Output of `Sam::embeddings` for the resized image.
data: Tensor,
|
||||
}
|
||||
|
||||
/// Segment-anything model exposed to JavaScript through `wasm_bindgen`.
#[wasm_bindgen]
|
||||
pub struct Model {
|
||||
// The underlying segment-anything model.
sam: sam::Sam,
|
||||
// Embeddings of the current image; `None` until `set_image_embeddings` has succeeded.
embeddings: Option<Embeddings>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl Model {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let sam = if use_tiny {
|
||||
sam::Sam::new_tiny(vb)? // tiny vit_t
|
||||
} else {
|
||||
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
|
||||
};
|
||||
Ok(Self {
|
||||
sam,
|
||||
embeddings: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
|
||||
sam::console_log!("image data: {}", image_data.len());
|
||||
let image_data = std::io::Cursor::new(image_data);
|
||||
let image = image::io::Reader::new(image_data)
|
||||
.with_guessed_format()?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let (original_height, original_width) = (image.height(), image.width());
|
||||
let (height, width) = (original_height, original_width);
|
||||
let resize_longest = sam::IMAGE_SIZE as u32;
|
||||
let (height, width) = if height < width {
|
||||
let h = (resize_longest * height) / width;
|
||||
(h, resize_longest)
|
||||
} else {
|
||||
let w = (resize_longest * width) / height;
|
||||
(resize_longest, w)
|
||||
};
|
||||
let image_t = {
|
||||
let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
|
||||
let data = img.to_rgb8().into_raw();
|
||||
Tensor::from_vec(
|
||||
data,
|
||||
(img.height() as usize, img.width() as usize, 3),
|
||||
&Device::Cpu,
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
let data = self.sam.embeddings(&image_t)?;
|
||||
self.embeddings = Some(Embeddings {
|
||||
original_width,
|
||||
original_height,
|
||||
width,
|
||||
height,
|
||||
data,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// x and y have to be between 0 and 1
|
||||
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
|
||||
let embeddings = match &self.embeddings {
|
||||
None => todo!(),
|
||||
Some(embeddings) => embeddings,
|
||||
};
|
||||
let (mask, iou_predictions) = self.sam.forward_for_embeddings(
|
||||
&embeddings.data,
|
||||
embeddings.height as usize,
|
||||
embeddings.width as usize,
|
||||
Some((x, y)),
|
||||
false,
|
||||
)?;
|
||||
let iou = iou_predictions.to_vec1::<f32>()?[0];
|
||||
let mask_shape = mask.dims().to_vec();
|
||||
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
|
||||
let mask = Mask {
|
||||
iou,
|
||||
mask_shape,
|
||||
mask_data,
|
||||
};
|
||||
let json = serde_json::to_string(&mask)?;
|
||||
Ok(json)
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of `Model::mask_for_point`, serialized to JSON before being handed to JavaScript.
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct Mask {
|
||||
// First IoU prediction returned by the model for this mask.
iou: f32,
|
||||
// Dimensions of the mask tensor, as reported by `Tensor::dims`.
mask_shape: Vec<usize>,
|
||||
// Flattened result of comparing the mask values with zero (`mask.ge(0f32)`).
mask_data: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Entry point of the binary: only installs the panic hook, the rest of the API is exposed
/// through the `wasm_bindgen` items above.
fn main() {
|
||||
// `set_once` makes repeated installation harmless (it is also called in `Model::new`).
console_error_panic_hook::set_once();
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
use candle_transformers::models::segment_anything::sam;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub use sam::{Sam, IMAGE_SIZE};
|
||||
|
||||
#[wasm_bindgen]
|
||||
extern "C" {
|
||||
// Use `js_namespace` here to bind `console.log(..)` instead of just
|
||||
// `log(..)`.
|
||||
/// Binding to the JavaScript `console.log` function for a single string argument.
#[wasm_bindgen(js_namespace = console)]
|
||||
pub fn log(s: &str);
|
||||
}
|
||||
|
||||
/// `println!`-style macro that formats its arguments and sends the resulting string to the
/// `log` function exported by this crate.
#[macro_export]
|
||||
macro_rules! console_log {
|
||||
// Note that this is using the `log` function imported above through
|
||||
// `wasm_bindgen`; `$crate::log` keeps the path valid when used from other crates.
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
|
||||
}
|
Loading…
Reference in New Issue