chore: bump tch version (#345)

This commit is contained in:
Nathaniel Simard 2023-05-11 14:21:52 -04:00 committed by GitHub
parent 05763e1878
commit 747e245cc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 63 additions and 57 deletions

View File

@ -21,12 +21,12 @@ half = {workspace = true, features = ["std"]}
rand = {workspace = true, features = ["std"]} rand = {workspace = true, features = ["std"]}
[target.'cfg(not(target_arch = "aarch64"))'.dependencies] [target.'cfg(not(target_arch = "aarch64"))'.dependencies]
tch = {version = "0.11.0"} tch = {version = "0.12.0"}
# Temporary workaround for https://github.com/burn-rs/burn/issues/180 # Temporary workaround for https://github.com/burn-rs/burn/issues/180
# Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least # Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least
[target.'cfg(target_arch = "aarch64")'.dependencies] [target.'cfg(target_arch = "aarch64")'.dependencies]
tch = {version = "0.11.0", default-features = false} # Disables torch downloading tch = {version = "0.12.0", default-features = false} # Disables torch downloading
[dev-dependencies] [dev-dependencies]
burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [ burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [

View File

@ -15,7 +15,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
) -> TchTensor<E, D2> { ) -> TchTensor<E, D2> {
let shape_tch: TchShape<D2> = shape.into(); let shape_tch: TchShape<D2> = shape.into();
TchTensor::from_existing(tensor.tensor.reshape(&shape_tch.dims), tensor.storage) TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
} }
pub fn index<const D1: usize, const D2: usize>( pub fn index<const D1: usize, const D2: usize>(
@ -42,7 +42,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
let tensor_original = tensor.tensor.copy(); let tensor_original = tensor.tensor.copy();
let tch_shape = TchShape::from(tensor.shape()); let tch_shape = TchShape::from(tensor.shape());
let mut tensor = tensor_original.view_(&tch_shape.dims); let mut tensor = tensor_original.view_(tch_shape.dims);
for (i, index) in indexes.into_iter().enumerate().take(D2) { for (i, index) in indexes.into_iter().enumerate().take(D2) {
let start = index.start as i64; let start = index.start as i64;

View File

@ -19,13 +19,15 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
} }
fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> { fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
let values: Vec<bool> = tensor.tensor.shallow_clone().into(); let shape = Self::bool_shape(tensor);
Data::new(values, tensor.shape()) let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Data::new(values.unwrap(), shape)
} }
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> { fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
let shape = tensor.shape(); Self::bool_to_data(&tensor)
Data::new(tensor.tensor.into(), shape)
} }
fn bool_to_device<const D: usize>( fn bool_to_device<const D: usize>(
@ -51,7 +53,7 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
device: &<TchBackend<E> as Backend>::Device, device: &<TchBackend<E> as Backend>::Device,
) -> TchTensor<bool, D> { ) -> TchTensor<bool, D> {
let tensor = tch::Tensor::empty( let tensor = tch::Tensor::empty(
&shape.dims.map(|a| a as i64), shape.dims.map(|a| a as i64),
(tch::Kind::Bool, (*device).into()), (tch::Kind::Bool, (*device).into()),
); );

View File

@ -16,13 +16,15 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
} }
fn int_to_data<const D: usize>(tensor: &TchTensor<i64, D>) -> Data<i64, D> { fn int_to_data<const D: usize>(tensor: &TchTensor<i64, D>) -> Data<i64, D> {
let values: Vec<i64> = tensor.tensor.shallow_clone().into(); let shape = Self::int_shape(tensor);
Data::new(values, tensor.shape()) let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Data::new(values.unwrap(), shape)
} }
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Data<i64, D> { fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Data<i64, D> {
let shape = tensor.shape(); Self::int_to_data(&tensor)
Data::new(tensor.tensor.into(), shape)
} }
fn int_to_device<const D: usize>( fn int_to_device<const D: usize>(
@ -48,7 +50,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
device: &<TchBackend<E> as Backend>::Device, device: &<TchBackend<E> as Backend>::Device,
) -> TchTensor<i64, D> { ) -> TchTensor<i64, D> {
let tensor = tch::Tensor::empty( let tensor = tch::Tensor::empty(
&shape.dims.map(|a| a as i64), shape.dims.map(|a| a as i64),
(tch::Kind::Int64, (*device).into()), (tch::Kind::Int64, (*device).into()),
); );
@ -201,7 +203,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
let shape = TchShape::from(shape); let shape = TchShape::from(shape);
let device: tch::Device = (*device).into(); let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::zeros(&shape.dims, (tch::Kind::Int64, device))) TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device)))
} }
fn int_ones<const D: usize>( fn int_ones<const D: usize>(
@ -211,7 +213,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
let shape = TchShape::from(shape); let shape = TchShape::from(shape);
let device: tch::Device = (*device).into(); let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::ones(&shape.dims, (tch::Kind::Int64, device))) TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device)))
} }
fn int_sum<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> { fn int_sum<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {

View File

@ -38,9 +38,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
&x.tensor, &x.tensor,
&weight.tensor, &weight.tensor,
bias.map(|t| t.tensor), bias.map(|t| t.tensor),
&options.stride.map(|i| i as i64), options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64), options.padding.map(|i| i as i64),
&options.dilation.map(|i| i as i64), options.dilation.map(|i| i as i64),
options.groups as i64, options.groups as i64,
); );
@ -57,9 +57,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
&x.tensor, &x.tensor,
&weight.tensor, &weight.tensor,
bias.map(|t| t.tensor), bias.map(|t| t.tensor),
&options.stride.map(|i| i as i64), options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64), options.padding.map(|i| i as i64),
&options.dilation.map(|i| i as i64), options.dilation.map(|i| i as i64),
options.groups as i64, options.groups as i64,
); );
@ -76,11 +76,11 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
&x.tensor, &x.tensor,
&weight.tensor, &weight.tensor,
bias.map(|t| t.tensor), bias.map(|t| t.tensor),
&options.stride.map(|i| i as i64), options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64), options.padding.map(|i| i as i64),
&options.padding_out.map(|i| i as i64), options.padding_out.map(|i| i as i64),
options.groups as i64, options.groups as i64,
&options.dilation.map(|i| i as i64), options.dilation.map(|i| i as i64),
); );
TchTensor::new(tensor) TchTensor::new(tensor)
@ -96,11 +96,11 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
&x.tensor, &x.tensor,
&weight.tensor, &weight.tensor,
bias.map(|t| t.tensor), bias.map(|t| t.tensor),
&options.stride.map(|i| i as i64), options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64), options.padding.map(|i| i as i64),
&options.padding_out.map(|i| i as i64), options.padding_out.map(|i| i as i64),
options.groups as i64, options.groups as i64,
&options.dilation.map(|i| i as i64), options.dilation.map(|i| i as i64),
); );
TchTensor::new(tensor) TchTensor::new(tensor)
@ -114,9 +114,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
) -> TchTensor<E, 4> { ) -> TchTensor<E, 4> {
let tensor = tch::Tensor::avg_pool2d( let tensor = tch::Tensor::avg_pool2d(
&x.tensor, &x.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64], [kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64], [stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64], [padding[0] as i64, padding[1] as i64],
false, false,
true, true,
None, None,
@ -135,9 +135,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
let tensor = tch::Tensor::avg_pool2d_backward( let tensor = tch::Tensor::avg_pool2d_backward(
&x.tensor, &x.tensor,
&grad.tensor, &grad.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64], [kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64], [stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64], [padding[0] as i64, padding[1] as i64],
false, false,
true, true,
None, None,
@ -154,10 +154,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
) -> TchTensor<E, 4> { ) -> TchTensor<E, 4> {
let tensor = tch::Tensor::max_pool2d( let tensor = tch::Tensor::max_pool2d(
&x.tensor, &x.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64], [kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64], [stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64], [padding[0] as i64, padding[1] as i64],
&[1, 1], [1, 1],
false, false,
); );
@ -172,10 +172,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
) -> MaxPool2dWithIndexes<TchBackend<E>> { ) -> MaxPool2dWithIndexes<TchBackend<E>> {
let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices( let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
&x.tensor, &x.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64], [kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64], [stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64], [padding[0] as i64, padding[1] as i64],
&[1, 1], [1, 1],
false, false,
); );
@ -193,10 +193,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
let grad = tch::Tensor::max_pool2d_with_indices_backward( let grad = tch::Tensor::max_pool2d_with_indices_backward(
&x.tensor, &x.tensor,
&output_grad.tensor, &output_grad.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64], [kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64], [stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64], [padding[0] as i64, padding[1] as i64],
&[1, 1], [1, 1],
false, false,
&indexes.tensor, &indexes.tensor,
); );

View File

@ -52,14 +52,14 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
let shape = TchShape::from(shape); let shape = TchShape::from(shape);
let device: tch::Device = (*device).into(); let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::zeros(&shape.dims, (E::KIND, device))) TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device)))
} }
fn ones<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> { fn ones<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
let shape = TchShape::from(shape); let shape = TchShape::from(shape);
let device: tch::Device = (*device).into(); let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::ones(&shape.dims, (E::KIND, device))) TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device)))
} }
fn shape<const D: usize>(tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>) -> Shape<D> { fn shape<const D: usize>(tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>) -> Shape<D> {
@ -69,15 +69,17 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn to_data<const D: usize>( fn to_data<const D: usize>(
tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>, tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>,
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> { ) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
let values: Vec<E> = tensor.tensor.shallow_clone().into(); let shape = Self::shape(tensor);
Data::new(values, tensor.shape()) let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Data::new(values.unwrap(), shape)
} }
fn into_data<const D: usize>( fn into_data<const D: usize>(
tensor: <TchBackend<E> as Backend>::TensorPrimitive<D>, tensor: <TchBackend<E> as Backend>::TensorPrimitive<D>,
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> { ) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
let shape = tensor.shape(); Self::to_data(&tensor)
Data::new(tensor.tensor.into(), shape)
} }
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice { fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
@ -92,7 +94,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
shape: Shape<D>, shape: Shape<D>,
device: &<TchBackend<E> as Backend>::Device, device: &<TchBackend<E> as Backend>::Device,
) -> <TchBackend<E> as Backend>::TensorPrimitive<D> { ) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
let tensor = tch::Tensor::empty(&shape.dims.map(|a| a as i64), (E::KIND, (*device).into())); let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into()));
TchTensor::new(tensor) TchTensor::new(tensor)
} }

View File

@ -168,7 +168,7 @@ impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self { pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
let tensor = tch::Tensor::of_slice(data.value.as_slice()).to(device); let tensor = tch::Tensor::of_slice(data.value.as_slice()).to(device);
let shape_tch = TchShape::from(data.shape); let shape_tch = TchShape::from(data.shape);
let tensor = tensor.reshape(&shape_tch.dims).to_kind(E::KIND); let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);
Self::new(tensor) Self::new(tensor)
} }
@ -192,7 +192,7 @@ mod utils {
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> { impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
pub fn empty(shape: Shape<D>, device: TchDevice) -> Self { pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
let shape_tch = TchShape::from(shape); let shape_tch = TchShape::from(shape);
let tensor = tch::Tensor::empty(&shape_tch.dims, (E::KIND, device.into())); let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
Self::new(tensor) Self::new(tensor)
} }