Remove the unused pragma in vit + handle the final layernorm. (#1688)

Laurent Mazare 2024-02-10 11:08:50 +01:00 committed by GitHub
parent 1c8d61f051
commit 67589791d2
1 changed file with 9 additions and 7 deletions


@@ -1,4 +1,3 @@
-#![allow(unused)]
 use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
@@ -82,7 +81,7 @@ impl PatchEmbeddings {
 impl Module for PatchEmbeddings {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
         self.projection
             .forward(pixel_values)?
             .flatten_from(2)?
@@ -123,9 +122,9 @@ impl Embeddings {
     fn interpolate_pos_encoding(
         &self,
-        embeddings: &Tensor,
-        height: usize,
-        width: usize,
+        _embeddings: &Tensor,
+        _height: usize,
+        _width: usize,
     ) -> Result<Tensor> {
         todo!()
     }
@@ -136,7 +135,7 @@ impl Embeddings {
         bool_masked_pos: Option<&Tensor>,
         interpolate_pos_encoding: bool,
     ) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
         let embeddings = self.patch_embeddings.forward(pixel_values)?;
         let embeddings = match (bool_masked_pos, &self.mask_token) {
            (None, _) => embeddings,
@@ -392,6 +391,9 @@ impl Model {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(xs, None, false)?;
         let encoder_outputs = self.encoder.forward(&embedding_output)?;
-        encoder_outputs.i((.., 0, ..))?.apply(&self.classifier)
+        encoder_outputs
+            .i((.., 0, ..))?
+            .apply(&self.layernorm)?
+            .apply(&self.classifier)
     }
 }
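For reference, a minimal standalone sketch of the pattern this patch introduces in Model::forward: take the CLS token from the encoder output, apply the final layernorm, then the classifier head. The sizes (hidden size 768, 197 tokens, 10 classes), the variable names, and the zero-initialized VarBuilder are illustrative assumptions, not values from the actual ViT config.

use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, linear, VarBuilder};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative sizes, not taken from a real checkpoint: hidden size 768,
    // 197 tokens (196 patches + CLS), 10 output classes.
    let (hidden, seq_len, classes) = (768, 197, 10);
    let vb = VarBuilder::zeros(DType::F32, &device);
    let layernorm = layer_norm(hidden, 1e-6, vb.pp("layernorm"))?;
    let classifier = linear(hidden, classes, vb.pp("classifier"))?;
    // Stand-in for the encoder output with shape (batch, seq_len, hidden).
    let encoder_outputs = Tensor::zeros((1, seq_len, hidden), DType::F32, &device)?;
    // Take the CLS token, normalize it, then classify -- the same chain as the
    // patched forward pass above.
    let logits = encoder_outputs
        .i((.., 0, ..))?
        .apply(&layernorm)?
        .apply(&classifier)?;
    println!("logits shape: {:?}", logits.dims()); // [1, 10]
    Ok(())
}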