mirror of https://github.com/tracel-ai/burn.git
chore: bump tch version (#345)
This commit is contained in:
parent
05763e1878
commit
747e245cc4
|
@ -21,12 +21,12 @@ half = {workspace = true, features = ["std"]}
|
|||
rand = {workspace = true, features = ["std"]}
|
||||
|
||||
[target.'cfg(not(target_arch = "aarch64"))'.dependencies]
|
||||
tch = {version = "0.11.0"}
|
||||
tch = {version = "0.12.0"}
|
||||
|
||||
# Temporary workaround for https://github.com/burn-rs/burn/issues/180
|
||||
# Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least
|
||||
[target.'cfg(target_arch = "aarch64")'.dependencies]
|
||||
tch = {version = "0.11.0", default-features = false} # Disables torch downloading
|
||||
tch = {version = "0.12.0", default-features = false} # Disables torch downloading
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
|
||||
|
|
|
@ -15,7 +15,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
) -> TchTensor<E, D2> {
|
||||
let shape_tch: TchShape<D2> = shape.into();
|
||||
|
||||
TchTensor::from_existing(tensor.tensor.reshape(&shape_tch.dims), tensor.storage)
|
||||
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
|
||||
}
|
||||
|
||||
pub fn index<const D1: usize, const D2: usize>(
|
||||
|
@ -42,7 +42,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
let tensor_original = tensor.tensor.copy();
|
||||
let tch_shape = TchShape::from(tensor.shape());
|
||||
|
||||
let mut tensor = tensor_original.view_(&tch_shape.dims);
|
||||
let mut tensor = tensor_original.view_(tch_shape.dims);
|
||||
|
||||
for (i, index) in indexes.into_iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
|
|
|
@ -19,13 +19,15 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
|
||||
let values: Vec<bool> = tensor.tensor.shallow_clone().into();
|
||||
Data::new(values, tensor.shape())
|
||||
let shape = Self::bool_shape(tensor);
|
||||
let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
|
||||
Data::new(values.unwrap(), shape)
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
|
||||
let shape = tensor.shape();
|
||||
Data::new(tensor.tensor.into(), shape)
|
||||
Self::bool_to_data(&tensor)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
|
@ -51,7 +53,7 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
device: &<TchBackend<E> as Backend>::Device,
|
||||
) -> TchTensor<bool, D> {
|
||||
let tensor = tch::Tensor::empty(
|
||||
&shape.dims.map(|a| a as i64),
|
||||
shape.dims.map(|a| a as i64),
|
||||
(tch::Kind::Bool, (*device).into()),
|
||||
);
|
||||
|
||||
|
|
|
@ -16,13 +16,15 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
}
|
||||
|
||||
fn int_to_data<const D: usize>(tensor: &TchTensor<i64, D>) -> Data<i64, D> {
|
||||
let values: Vec<i64> = tensor.tensor.shallow_clone().into();
|
||||
Data::new(values, tensor.shape())
|
||||
let shape = Self::int_shape(tensor);
|
||||
let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
|
||||
Data::new(values.unwrap(), shape)
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Data<i64, D> {
|
||||
let shape = tensor.shape();
|
||||
Data::new(tensor.tensor.into(), shape)
|
||||
Self::int_to_data(&tensor)
|
||||
}
|
||||
|
||||
fn int_to_device<const D: usize>(
|
||||
|
@ -48,7 +50,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
device: &<TchBackend<E> as Backend>::Device,
|
||||
) -> TchTensor<i64, D> {
|
||||
let tensor = tch::Tensor::empty(
|
||||
&shape.dims.map(|a| a as i64),
|
||||
shape.dims.map(|a| a as i64),
|
||||
(tch::Kind::Int64, (*device).into()),
|
||||
);
|
||||
|
||||
|
@ -201,7 +203,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::zeros(&shape.dims, (tch::Kind::Int64, device)))
|
||||
TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device)))
|
||||
}
|
||||
|
||||
fn int_ones<const D: usize>(
|
||||
|
@ -211,7 +213,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::ones(&shape.dims, (tch::Kind::Int64, device)))
|
||||
TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device)))
|
||||
}
|
||||
|
||||
fn int_sum<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
|
||||
|
|
|
@ -38,9 +38,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.dilation.map(|i| i as i64),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
|
@ -57,9 +57,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.dilation.map(|i| i as i64),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
|
@ -76,11 +76,11 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.padding_out.map(|i| i as i64),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
&options.dilation.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
@ -96,11 +96,11 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.padding_out.map(|i| i as i64),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
&options.dilation.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
@ -114,9 +114,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::avg_pool2d(
|
||||
&x.tensor,
|
||||
&[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
false,
|
||||
true,
|
||||
None,
|
||||
|
@ -135,9 +135,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
let tensor = tch::Tensor::avg_pool2d_backward(
|
||||
&x.tensor,
|
||||
&grad.tensor,
|
||||
&[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
false,
|
||||
true,
|
||||
None,
|
||||
|
@ -154,10 +154,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::max_pool2d(
|
||||
&x.tensor,
|
||||
&[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[1, 1],
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
false,
|
||||
);
|
||||
|
||||
|
@ -172,10 +172,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
) -> MaxPool2dWithIndexes<TchBackend<E>> {
|
||||
let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
|
||||
&x.tensor,
|
||||
&[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[1, 1],
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
false,
|
||||
);
|
||||
|
||||
|
@ -193,10 +193,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
let grad = tch::Tensor::max_pool2d_with_indices_backward(
|
||||
&x.tensor,
|
||||
&output_grad.tensor,
|
||||
&[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[1, 1],
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
false,
|
||||
&indexes.tensor,
|
||||
);
|
||||
|
|
|
@ -52,14 +52,14 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::zeros(&shape.dims, (E::KIND, device)))
|
||||
TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device)))
|
||||
}
|
||||
|
||||
fn ones<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
|
||||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::ones(&shape.dims, (E::KIND, device)))
|
||||
TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device)))
|
||||
}
|
||||
|
||||
fn shape<const D: usize>(tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>) -> Shape<D> {
|
||||
|
@ -69,15 +69,17 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
fn to_data<const D: usize>(
|
||||
tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>,
|
||||
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
|
||||
let values: Vec<E> = tensor.tensor.shallow_clone().into();
|
||||
Data::new(values, tensor.shape())
|
||||
let shape = Self::shape(tensor);
|
||||
let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
|
||||
Data::new(values.unwrap(), shape)
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(
|
||||
tensor: <TchBackend<E> as Backend>::TensorPrimitive<D>,
|
||||
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
|
||||
let shape = tensor.shape();
|
||||
Data::new(tensor.tensor.into(), shape)
|
||||
Self::to_data(&tensor)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
|
||||
|
@ -92,7 +94,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
shape: Shape<D>,
|
||||
device: &<TchBackend<E> as Backend>::Device,
|
||||
) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
|
||||
let tensor = tch::Tensor::empty(&shape.dims.map(|a| a as i64), (E::KIND, (*device).into()));
|
||||
let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into()));
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
|
|
@ -168,7 +168,7 @@ impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
|
|||
pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
|
||||
let tensor = tch::Tensor::of_slice(data.value.as_slice()).to(device);
|
||||
let shape_tch = TchShape::from(data.shape);
|
||||
let tensor = tensor.reshape(&shape_tch.dims).to_kind(E::KIND);
|
||||
let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);
|
||||
|
||||
Self::new(tensor)
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ mod utils {
|
|||
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
|
||||
pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
|
||||
let shape_tch = TchShape::from(shape);
|
||||
let tensor = tch::Tensor::empty(&shape_tch.dims, (E::KIND, device.into()));
|
||||
let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
|
||||
|
||||
Self::new(tensor)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue