Implement group-norm. (#334)
* Implement group-norm. * Add some testing for group-norm.
This commit is contained in:
parent
2c9f605976
commit
5bb2fce998
|
@ -6,6 +6,7 @@
|
|||
//!
|
||||
//! https://github.com/openai/CLIP
|
||||
use candle::{Device, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Activation {
|
||||
|
@ -16,7 +17,7 @@ pub enum Activation {
|
|||
impl Activation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?,
|
||||
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
|
||||
Activation::Gelu => xs.gelu(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Add sigmoid as binary ops.
|
||||
(xs.neg()?.exp()? - 1.0)?.recip()
|
||||
}
|
||||
|
||||
pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
//! Group Normalization.
|
||||
//!
|
||||
//! This layer applies Group Normalization over a mini-batch of inputs.
|
||||
use candle::{Result, Tensor};
|
||||
use candle::{DType, Result, Tensor};
|
||||
|
||||
// This group norm version handles both weight and bias so removes the mean.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
pub struct GroupNorm {
|
||||
weight: Tensor,
|
||||
|
@ -21,18 +20,50 @@ impl GroupNorm {
|
|||
num_channels: usize,
|
||||
num_groups: usize,
|
||||
eps: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
) -> Result<Self> {
|
||||
if num_channels % num_groups != 0 {
|
||||
candle::bail!(
|
||||
"GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
|
||||
)
|
||||
}
|
||||
Ok(Self {
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
num_channels,
|
||||
num_groups,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, _: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_shape = x.dims();
|
||||
if x_shape.len() <= 2 {
|
||||
candle::bail!("input rank for GroupNorm should be at least 3");
|
||||
}
|
||||
let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
|
||||
let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
|
||||
if n_channels != self.num_channels {
|
||||
candle::bail!(
|
||||
"unexpected num-channels in GroupNorm ({n_channels} <> {}",
|
||||
self.num_channels
|
||||
)
|
||||
}
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?
|
||||
.reshape(x_shape)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,5 +75,5 @@ pub fn group_norm(
|
|||
) -> Result<GroupNorm> {
|
||||
let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
|
||||
let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
|
||||
Ok(GroupNorm::new(weight, bias, num_channels, num_groups, eps))
|
||||
GroupNorm::new(weight, bias, num_channels, num_groups, eps)
|
||||
}
|
||||
|
|
|
@ -34,5 +34,11 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
|||
}
|
||||
|
||||
pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
(xs.neg()?.exp()? + 1.0)?.recip()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
/* Equivalent PyTorch code.
|
||||
import torch
|
||||
from torch.nn.functional import group_norm
|
||||
t = torch.tensor(
|
||||
[[[-0.3034, 0.2726, -0.9659],
|
||||
[-1.1845, -1.3236, 0.0172],
|
||||
[ 1.9507, 1.2554, -0.8625],
|
||||
[ 1.0682, 0.3604, 0.3985],
|
||||
[-0.4957, -0.4461, -0.9721],
|
||||
[ 1.5157, -0.1546, -0.5596]],
|
||||
|
||||
[[-1.6698, -0.4040, -0.7927],
|
||||
[ 0.3736, -0.0975, -0.1351],
|
||||
[-0.9461, 0.5461, -0.6334],
|
||||
[-1.0919, -0.1158, 0.1213],
|
||||
[-0.9535, 0.1281, 0.4372],
|
||||
[-0.2845, 0.3488, 0.5641]]])
|
||||
print(group_norm(t, num_groups=2))
|
||||
print(group_norm(t, num_groups=3))
|
||||
*/
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::GroupNorm;
|
||||
mod test_utils;
|
||||
use test_utils::to_vec3_round;
|
||||
|
||||
#[test]
|
||||
fn group_norm() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
let w = Tensor::new(&[1f32], device)?;
|
||||
let b = Tensor::new(&[0f32], device)?;
|
||||
let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
|
||||
let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;
|
||||
|
||||
let input = Tensor::new(
|
||||
&[
|
||||
[
|
||||
[-0.3034f32, 0.2726, -0.9659],
|
||||
[-1.1845, -1.3236, 0.0172],
|
||||
[1.9507, 1.2554, -0.8625],
|
||||
[1.0682, 0.3604, 0.3985],
|
||||
[-0.4957, -0.4461, -0.9721],
|
||||
[1.5157, -0.1546, -0.5596],
|
||||
],
|
||||
[
|
||||
[-1.6698, -0.4040, -0.7927],
|
||||
[0.3736, -0.0975, -0.1351],
|
||||
[-0.9461, 0.5461, -0.6334],
|
||||
[-1.0919, -0.1158, 0.1213],
|
||||
[-0.9535, 0.1281, 0.4372],
|
||||
[-0.2845, 0.3488, 0.5641],
|
||||
],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(gn2.forward(&input)?, 4)?,
|
||||
&[
|
||||
[
|
||||
[-0.1653, 0.3748, -0.7866],
|
||||
[-0.9916, -1.1220, 0.1353],
|
||||
[1.9485, 1.2965, -0.6896],
|
||||
[1.2769, 0.3628, 0.4120],
|
||||
[-0.7427, -0.6786, -1.3578],
|
||||
[1.8547, -0.3022, -0.8252]
|
||||
],
|
||||
[
|
||||
[-1.9342, 0.0211, -0.5793],
|
||||
[1.2223, 0.4945, 0.4365],
|
||||
[-0.8163, 1.4887, -0.3333],
|
||||
[-1.7960, -0.0392, 0.3875],
|
||||
[-1.5469, 0.3998, 0.9561],
|
||||
[-0.3428, 0.7970, 1.1845]
|
||||
]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
to_vec3_round(gn3.forward(&input)?, 4)?,
|
||||
&[
|
||||
[
|
||||
[0.4560, 1.4014, -0.6313],
|
||||
[-0.9901, -1.2184, 0.9822],
|
||||
[1.4254, 0.6360, -1.7682],
|
||||
[0.4235, -0.3800, -0.3367],
|
||||
[-0.3890, -0.3268, -0.9862],
|
||||
[2.1325, 0.0386, -0.4691]
|
||||
],
|
||||
[
|
||||
[-1.8797, 0.0777, -0.5234],
|
||||
[1.2802, 0.5517, 0.4935],
|
||||
[-1.0102, 1.5327, -0.4773],
|
||||
[-1.2587, 0.4047, 0.8088],
|
||||
[-1.9074, 0.1691, 0.7625],
|
||||
[-0.6230, 0.5928, 1.0061]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue