mirror of https://github.com/tracel-ai/burn.git
bump candle to 0.3.1 and conv_transpose_1d (#977)
This commit is contained in:
parent
cdf54d0b40
commit
4711db0e18
|
@ -18,8 +18,8 @@ derive-new = { workspace = true }
|
|||
burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = false }
|
||||
half = { workspace = true }
|
||||
|
||||
# TODO remove pinned version ("=") once candle-core is updated to 0.3.1
|
||||
candle-core = { version = "=0.3.0" }
|
||||
candle-core = { version = "0.3.1" }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", default-features = false, features = [
|
||||
|
|
|
@ -50,6 +50,7 @@ impl From<candle_core::Device> for CandleDevice {
|
|||
match device.location() {
|
||||
DeviceLocation::Cpu => CandleDevice::Cpu,
|
||||
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
|
||||
DeviceLocation::Metal => panic!("Metal unsupported"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,7 +83,26 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
|||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self, 3> {
|
||||
panic!("Candle does not support conv_transpose1d")
|
||||
assert!(
|
||||
options.groups == 1,
|
||||
"Candle does not support groups in transposed convolutions"
|
||||
);
|
||||
let conv_transpose = x
|
||||
.tensor
|
||||
.conv_transpose1d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv_transpose
|
||||
.broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
|
||||
.unwrap(),
|
||||
None => conv_transpose,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
|
|
Loading…
Reference in New Issue