Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.
* Avoid some unnecessary groups checks.
* Move the tensor convolution bits.
* Proper handling of groups.
* Bump the crate version.
* And add a changelog.
parent 4ee1cf038a
commit aba1e90797
@@ -0,0 +1,13 @@
+# Changelog
+This documents the main changes to the `candle` crate.
+
+## Unreleased
+### Added
+- Add a group parameter to convolutions
+  [566](https://github.com/huggingface/candle/pull/566).
+- New dtype: int64
+  [563](https://github.com/huggingface/candle/pull/563).
+- Handling of the GGUF file format.
+  [559](https://github.com/huggingface/candle/pull/559).
+
+## v0.1.2 - 2023-08-21
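The change is API-breaking: `Tensor::conv1d` and `Tensor::conv2d` now take a trailing `groups` argument, and passing `1` reproduces the previous behavior, which is what every call site below does. A minimal sketch of the new signature, assuming a direct dependency on the `candle-core` crate (the shapes here are illustrative, not taken from the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Input is (batch, c_in, length); the kernel is (c_out, c_in / groups, k_size).
    let x = Tensor::randn(0f32, 1f32, (1, 8, 32), &Device::Cpu)?;
    let k = Tensor::randn(0f32, 1f32, (8, 4, 3), &Device::Cpu)?;
    // groups = 2 splits the 8 input channels into two blocks of 4, matching
    // the kernel's 4 in-channels; groups = 1 is the pre-change behavior.
    let y = x.conv1d(&k, /* padding */ 0, /* stride */ 1, /* groups */ 2)?;
    println!("{:?}", y.dims());
    Ok(())
}
```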
@@ -16,7 +16,7 @@ exclude = [
 ]
 
 [workspace.package]
-version = "0.1.2"
+version = "0.1.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -12,7 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.1.2", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.1.3", optional = true }
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
     let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
     let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
     let start = std::time::Instant::now();
-    let res = inp.conv2d(&w, 0, 1);
+    let res = inp.conv2d(&w, 0, 1, 1)?;
     println!("{:?}", start.elapsed());
     println!("{res:?}");
     Ok(())
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
     }
 
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        d.0.conv1d(&d.1, 0, 1)
+        d.0.conv1d(&d.1, 0, 1, 1)
     }
 
     const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
     }
 
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        d.0.conv2d(&d.1, 0, 1)
+        d.0.conv2d(&d.1, 0, 1, 1)
     }
 
     const ITERS: usize = 1;
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
     let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
-    let res = t.conv2d(&w, 1, 1)?;
+    let res = t.conv2d(&w, 1, 1, 1)?;
     println!("{res:?}");
     Ok(())
 }
@@ -1,3 +1,5 @@
+use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct ParamsConv1D {
     pub(crate) b_size: usize,
@@ -51,3 +53,113 @@ impl ParamsConv2D {
         vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
     }
 }
+
+impl Tensor {
+    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
+        let storage =
+            self.storage()
+                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+            arg,
+            kernel,
+            padding: params.padding,
+            stride: params.stride,
+        });
+        let out_dims = params.out_dims();
+        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+    }
+
+    /// Applies a 1D convolution over the input tensor.
+    pub fn conv1d(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        stride: usize,
+        groups: usize,
+    ) -> Result<Self> {
+        let (c_out, c_in_k, k_size) = kernel.dims3()?;
+        let (b_size, c_in, l_in) = self.dims3()?;
+        if c_in != c_in_k * groups {
+            Err(Error::Conv1dInvalidArgs {
+                inp_shape: self.shape().clone(),
+                k_shape: kernel.shape().clone(),
+                padding,
+                stride,
+                msg: "the number of in-channels on the input doesn't match the kernel size",
+            }
+            .bt())?
+        }
+
+        let params = ParamsConv1D {
+            b_size,
+            l_in,
+            c_out,
+            c_in,
+            k_size,
+            padding,
+            stride,
+        };
+        if groups == 1 {
+            self.conv1d_single_group(kernel, &params)
+        } else {
+            let blocks = self.chunk(groups, 1)?;
+            let blocks = blocks
+                .iter()
+                .map(|block| block.conv1d_single_group(kernel, &params))
+                .collect::<Result<Vec<_>>>()?;
+            Tensor::cat(&blocks, 1)
+        }
+    }
+
+    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
+        let storage =
+            self.storage()
+                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+            arg,
+            kernel,
+            padding: params.padding,
+            stride: params.stride,
+        });
+        let out_dims = params.out_dims();
+        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+    }
+
+    /// Applies a 2D convolution over the input tensor.
+    pub fn conv2d(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        stride: usize,
+        groups: usize,
+    ) -> Result<Self> {
+        let (b_size, c_in, i_h, i_w) = self.dims4()?;
+        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+        if c_in != c_in_k * groups {
+            crate::bail!(
+                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
+            )
+        }
+        let params = ParamsConv2D {
+            b_size,
+            i_h,
+            i_w,
+            k_h,
+            k_w,
+            c_out,
+            c_in,
+            padding,
+            stride,
+        };
+        if groups == 1 {
+            self.conv2d_single_group(kernel, &params)
+        } else {
+            let blocks = self.chunk(groups, 1)?;
+            let blocks = blocks
+                .iter()
+                .map(|block| block.conv2d_single_group(kernel, &params))
+                .collect::<Result<Vec<_>>>()?;
+            Tensor::cat(&blocks, 1)
+        }
+    }
+}
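The non-trivial part is the `groups != 1` branch: rather than teaching every backend about grouped kernels, the implementation reuses existing ops by chunking the input along the channel dimension, convolving each block as a single-group conv, and concatenating the results. The same decomposition can be written against the public API (an illustrative helper, not code from the diff):

```rust
use candle_core::{Result, Tensor};

/// Sketch of the decomposition used above: a grouped 1D convolution expressed
/// as `groups` independent single-group convolutions over channel blocks,
/// concatenated back along the channel dimension.
fn grouped_conv1d_via_chunks(
    x: &Tensor,
    kernel: &Tensor,
    padding: usize,
    stride: usize,
    groups: usize,
) -> Result<Tensor> {
    // (b, c_in, l) -> `groups` blocks of shape (b, c_in / groups, l).
    let blocks = x.chunk(groups, 1)?;
    let outs = blocks
        .iter()
        .map(|block| block.conv1d(kernel, padding, stride, 1))
        .collect::<Result<Vec<_>>>()?;
    // Stitch the per-group outputs back together on the channel dimension.
    Tensor::cat(&outs, 1)
}
```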
@@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op {
 }
 
 /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage<S: Into<Shape>>(
+pub(crate) fn from_storage<S: Into<Shape>>(
     storage: Storage,
     shape: S,
     op: BackpropOp,
@@ -787,72 +787,6 @@ impl Tensor {
         self.cmp(rhs, CmpOp::Le)
     }
 
-    /// Applies a 1D convolution over the input tensor.
-    pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
-        let (c_out, c_in_k, k_size) = kernel.dims3()?;
-        let (b_size, c_in, l_in) = self.dims3()?;
-        if c_in != c_in_k {
-            Err(Error::Conv1dInvalidArgs {
-                inp_shape: self.shape().clone(),
-                k_shape: kernel.shape().clone(),
-                padding,
-                stride,
-                msg: "the number of in-channels on the input doesn't match the kernel size",
-            }
-            .bt())?
-        }
-        let params = crate::conv::ParamsConv1D {
-            b_size,
-            l_in,
-            c_out,
-            c_in,
-            k_size,
-            padding,
-            stride,
-        };
-        let storage =
-            self.storage()
-                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
-        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
-            arg,
-            kernel,
-            padding,
-            stride,
-        });
-        let out_dims = params.out_dims();
-        Ok(from_storage(storage, out_dims, op, false))
-    }
-
-    pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
-        let (b_size, c_in, i_h, i_w) = self.dims4()?;
-        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
-        if c_in != c_in_k {
-            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
-        }
-        let params = crate::conv::ParamsConv2D {
-            b_size,
-            i_h,
-            i_w,
-            k_h,
-            k_w,
-            c_out,
-            c_in,
-            padding,
-            stride,
-        };
-        let storage =
-            self.storage()
-                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
-        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
-            arg,
-            kernel,
-            padding,
-            stride,
-        });
-        let out_dims = params.out_dims();
-        Ok(from_storage(storage, out_dims, op, false))
-    }
-
     pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
         let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
@@ -1920,7 +1854,7 @@ impl Tensor {
         }
     }
 
-    fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
+    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }
 
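Both visibility changes (`from_storage` above and `storage` here) exist to support the move: the convolution code now lives in `conv.rs`, so it needs to read a tensor's storage and build new tensors from raw storage from outside `tensor.rs`.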
@@ -33,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> {
         dev,
     )?
    .reshape((2, 4, 3))?;
-    let res = t.conv1d(&w, 0, 1)?;
+    let res = t.conv1d(&w, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
     );
-    let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+    let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 5]);
     // Same as pytorch default padding: use zeros.
     assert_eq!(
@@ -52,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> {
 fn conv1d_small(dev: &Device) -> Result<()> {
     let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
     let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
-    let res = t.conv1d(&w, 0, 1)?;
+    let res = t.conv1d(&w, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 2]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.4056, -0.8689]
     );
-    let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+    let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 4]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -109,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
     )?;
     let t = t.reshape((1, 4, 5, 5))?;
     let w = w.reshape((2, 4, 3, 3))?;
-    let res = t.conv2d(&w, 0, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -143,7 +143,7 @@ fn conv2d_small(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
     let t = t.reshape((1, 2, 3, 3))?;
     let w = w.reshape((1, 2, 1, 1))?;
-    let res = t.conv2d(&w, 0, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -162,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
     let t = t.reshape((1, 1, 3, 3))?;
     let w = w.reshape((1, 1, 3, 3))?;
-    let res = t.conv2d(&w, 0, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 1, 1]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
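All of the updated tests pass `1` for the new argument, so they only pin down the single-group path. A grouped case would look something like this (a hypothetical test, not part of the commit):

```rust
use candle_core::{Device, Result, Tensor};

#[test]
fn conv1d_grouped() -> Result<()> {
    let dev = &Device::Cpu;
    // 4 input channels, groups = 2: the kernel must have c_in_k = 4 / 2 = 2.
    let t = Tensor::randn(0f32, 1f32, (1, 4, 8), dev)?;
    let w = Tensor::randn(0f32, 1f32, (2, 2, 3), dev)?;
    let res = t.conv1d(&w, /* padding */ 0, /* stride */ 1, /* groups */ 2)?;
    // k_size = 3 with no padding shrinks the length from 8 to 6.
    assert_eq!(res.dims()[2], 6);
    Ok(())
}
```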
@@ -11,8 +11,8 @@ readme = "README.md"
 
 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.1.2" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.1.3" }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
@@ -11,11 +11,11 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.1.2" }
-candle-nn = { path = "../candle-nn", version = "0.1.2" }
-candle-transformers = { path = "../candle-transformers", version = "0.1.2" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.2", optional = true }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.1.3" }
+candle-nn = { path = "../candle-nn", version = "0.1.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.1.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.3", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -274,14 +274,22 @@ impl EncodecConv1d {
                 in_c,
                 out_c,
                 kernel_size,
-                Conv1dConfig { padding: 0, stride },
+                Conv1dConfig {
+                    padding: 0,
+                    stride,
+                    groups: 1,
+                },
                 vb.pp("conv"),
             )?,
             NormType::None => conv1d(
                 in_c,
                 out_c,
                 kernel_size,
-                Conv1dConfig { padding: 0, stride },
+                Conv1dConfig {
+                    padding: 0,
+                    stride,
+                    groups: 1,
+                },
                 vb.pp("conv"),
             )?,
         };
@@ -66,6 +66,7 @@ impl ResnetBlock2D {
         let conv_cfg = nn::Conv2dConfig {
             stride: 1,
             padding: 1,
+            groups: 1,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
         let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -79,6 +80,7 @@ impl ResnetBlock2D {
         let conv_cfg = nn::Conv2dConfig {
             stride: 1,
             padding: 0,
+            groups: 1,
         };
         Some(conv2d(
             in_channels,
@@ -112,8 +112,9 @@ impl UNet2DConditionModel {
         let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
         let time_embed_dim = b_channels * 4;
         let conv_cfg = nn::Conv2dConfig {
             stride: 1,
             padding: 1,
+            ..Default::default()
         };
         let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
 
@@ -24,7 +24,11 @@ impl Downsample2D {
         padding: usize,
     ) -> Result<Self> {
         let conv = if use_conv {
-            let config = nn::Conv2dConfig { stride: 2, padding };
+            let config = nn::Conv2dConfig {
+                stride: 2,
+                padding,
+                ..Default::default()
+            };
             let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
             Some(conv)
         } else {
@@ -51,8 +51,9 @@ impl Encoder {
         config: EncoderConfig,
     ) -> Result<Self> {
         let conv_cfg = nn::Conv2dConfig {
             stride: 1,
             padding: 1,
+            ..Default::default()
         };
         let conv_in = nn::conv2d(
             in_channels,
@@ -182,8 +182,9 @@ impl Decoder {
         let n_block_out_channels = config.block_out_channels.len();
         let last_block_out_channels = *config.block_out_channels.last().unwrap();
         let conv_cfg = nn::Conv2dConfig {
             stride: 1,
             padding: 1,
+            ..Default::default()
         };
         let conv_in = nn::conv2d(
             in_channels,
@@ -308,10 +308,12 @@ impl AudioEncoder {
         let cfg1 = Conv1dConfig {
             padding: 1,
             stride: 1,
+            groups: 1,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
+            groups: 1,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
@@ -128,7 +128,11 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
         }
         Some(_) | None => (None, true),
     };
-    let conv_cfg = candle_nn::Conv2dConfig { stride, padding };
+    let conv_cfg = candle_nn::Conv2dConfig {
+        stride,
+        padding,
+        groups: 1,
+    };
     let conv = if bias {
         conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
     } else {
@@ -101,7 +101,11 @@ impl ConvBlock {
         padding: Option<usize>,
     ) -> Result<Self> {
         let padding = padding.unwrap_or(k / 2);
-        let cfg = Conv2dConfig { padding, stride };
+        let cfg = Conv2dConfig {
+            padding,
+            stride,
+            groups: 1,
+        };
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
         Ok(Self { conv, bn })
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.1.2"
+version = "0.1.3"
 edition = "2021"
 
 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.1.3", package = "candle-core" }
 half = { version = "2.3.1", features = ["num-traits"] }
 
 [build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.1.2", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.1.3", features = ["cuda"] }
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.1.2"
+version = "0.1.3"
 edition = "2021"
 
 description = "CUDA kernels for Candle"
@@ -11,7 +11,7 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 safetensors = { workspace = true }
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
 pub struct Conv1dConfig {
     pub padding: usize,
     pub stride: usize,
+    pub groups: usize,
 }
 
 impl Default for Conv1dConfig {
@@ -12,6 +13,7 @@ impl Default for Conv1dConfig {
         Self {
             padding: 0,
             stride: 1,
+            groups: 1,
         }
     }
 }
@@ -39,7 +41,12 @@ impl Conv1d {
 
 impl crate::Module for Conv1d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
+        let x = x.conv1d(
+            &self.weight,
+            self.config.padding,
+            self.config.stride,
+            self.config.groups,
+        )?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => {
@@ -55,6 +62,7 @@ impl crate::Module for Conv1d {
 pub struct Conv2dConfig {
     pub padding: usize,
     pub stride: usize,
+    pub groups: usize,
 }
 
 impl Default for Conv2dConfig {
@@ -62,6 +70,7 @@ impl Default for Conv2dConfig {
         Self {
             padding: 0,
             stride: 1,
+            groups: 1,
         }
     }
 }
@@ -90,7 +99,12 @@ impl Conv2d {
 
 impl crate::Module for Conv2d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+        let x = x.conv2d(
+            &self.weight,
+            self.config.padding,
+            self.config.stride,
+            self.config.groups,
+        )?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => {
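With `groups` added to `Conv1dConfig` and `Conv2dConfig`, downstream struct literals no longer compile unless they set the field or fall back on struct-update syntax; the call sites in this commit use both styles. A small sketch of the two options:

```rust
use candle_nn::{Conv1dConfig, Conv2dConfig};

fn main() {
    // Explicit: spell out every field, as the whisper and yolo call sites do.
    let cfg1 = Conv1dConfig {
        padding: 1,
        stride: 1,
        groups: 1,
    };
    // Struct update: pick up `groups` (and any future fields) from `Default`,
    // as the stable-diffusion call sites do.
    let cfg2 = Conv2dConfig {
        stride: 1,
        padding: 1,
        ..Default::default()
    };
    assert_eq!(cfg1.groups, cfg2.groups);
}
```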
@@ -15,7 +15,7 @@ crate-type = ["cdylib"]
 doc = false
 
 [dependencies]
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
 half = { workspace = true }
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
 
@@ -11,8 +11,8 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.1.2" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.1.3" }
 intel-mkl-src = { workspace = true, optional = true }
 rand = { workspace = true }
 wav = { workspace = true }
@@ -9,8 +9,8 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.1.2" }
+candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.1.3" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
 
@@ -9,8 +9,8 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.1.2" }
+candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.1.3" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
 
@@ -295,10 +295,12 @@ impl AudioEncoder {
         let cfg1 = Conv1dConfig {
             padding: 1,
             stride: 1,
+            groups: 1,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
+            groups: 1,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
@@ -9,8 +9,8 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.1.2" }
+candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.1.3" }
 num-traits = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -97,7 +97,11 @@ impl ConvBlock {
         padding: Option<usize>,
     ) -> Result<Self> {
         let padding = padding.unwrap_or(k / 2);
-        let cfg = Conv2dConfig { padding, stride };
+        let cfg = Conv2dConfig {
+            padding,
+            stride,
+            groups: 1,
+        };
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
         Ok(Self { conv, bn })