Remove the unused pragma in vit + handle the final layernorm. (#1688)
parent 1c8d61f051
commit 67589791d2
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
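With the crate-level `#![allow(unused)]` gone, unused bindings have to be silenced one by one. The hunks below use Rust's underscore-prefix convention for this; a minimal standalone sketch (the function name is illustrative):

    use candle::{Result, Tensor};

    fn dims_example(t: &Tensor) -> Result<()> {
        // An underscore prefix marks each binding as intentionally unused,
        // replacing the old crate-wide `#![allow(unused)]`.
        let (_b, _c, _h, _w) = t.dims4()?;
        Ok(())
    }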
@@ -82,7 +81,7 @@ impl PatchEmbeddings {
 
 impl Module for PatchEmbeddings {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
         self.projection
             .forward(pixel_values)?
             .flatten_from(2)?
@@ -123,9 +122,9 @@ impl Embeddings {
 
     fn interpolate_pos_encoding(
         &self,
-        embeddings: &Tensor,
-        height: usize,
-        width: usize,
+        _embeddings: &Tensor,
+        _height: usize,
+        _width: usize,
     ) -> Result<Tensor> {
         todo!()
     }
@@ -136,7 +135,7 @@ impl Embeddings {
         bool_masked_pos: Option<&Tensor>,
         interpolate_pos_encoding: bool,
     ) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
         let embeddings = self.patch_embeddings.forward(pixel_values)?;
         let embeddings = match (bool_masked_pos, &self.mask_token) {
             (None, _) => embeddings,
@@ -392,6 +391,9 @@ impl Model {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(xs, None, false)?;
         let encoder_outputs = self.encoder.forward(&embedding_output)?;
-        encoder_outputs.i((.., 0, ..))?.apply(&self.classifier)
+        encoder_outputs
+            .i((.., 0, ..))?
+            .apply(&self.layernorm)?
+            .apply(&self.classifier)
     }
 }
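With this change, `Model::forward` applies the final layernorm to the CLS token before the classifier head instead of feeding the raw encoder output straight into it. A minimal usage sketch, assuming only the signatures visible in this diff (pixel values shaped (batch, channels, height, width), logits returned per batch entry); the helper name, the 224x224 input size, and the dummy data are illustrative:

    use candle::{DType, Device, Result, Tensor};
    use candle_transformers::models::vit::Model;

    fn classify(model: &Model, device: &Device) -> Result<Tensor> {
        // Zero-filled stand-in for a batch of one 224x224 RGB image.
        let pixel_values = Tensor::zeros((1, 3, 224, 224), DType::F32, device)?;
        // forward() selects the CLS token (index 0), applies the final
        // layernorm, then the classifier head, returning one logits row
        // per image.
        model.forward(&pixel_values)
    }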