From 747e245cc495c45ac3a69d1d96847e4cd2170050 Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Thu, 11 May 2023 14:21:52 -0400
Subject: [PATCH] chore: bump tch version (#345)

---
 burn-tch/Cargo.toml             |  4 +--
 burn-tch/src/ops/base.rs        |  4 +--
 burn-tch/src/ops/bool_tensor.rs | 12 ++++---
 burn-tch/src/ops/int_tensor.rs  | 16 +++++---
 burn-tch/src/ops/module.rs      | 64 ++++++++++++++++-----------------
 burn-tch/src/ops/tensor.rs      | 16 +++++----
 burn-tch/src/tensor.rs          |  4 +--
 7 files changed, 63 insertions(+), 57 deletions(-)

diff --git a/burn-tch/Cargo.toml b/burn-tch/Cargo.toml
index 75c3a5e8b..95200c5af 100644
--- a/burn-tch/Cargo.toml
+++ b/burn-tch/Cargo.toml
@@ -21,12 +21,12 @@ half = {workspace = true, features = ["std"]}
 rand = {workspace = true, features = ["std"]}
 
 [target.'cfg(not(target_arch = "aarch64"))'.dependencies]
-tch = {version = "0.11.0"}
+tch = {version = "0.12.0"}
 
 # Temporary workaround for https://github.com/burn-rs/burn/issues/180
 # Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least
 [target.'cfg(target_arch = "aarch64")'.dependencies]
-tch = {version = "0.11.0", default-features = false} # Disables torch downloading
+tch = {version = "0.12.0", default-features = false} # Disables torch downloading
 
 [dev-dependencies]
 burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs
index 0fa4b2778..66e1d05c1 100644
--- a/burn-tch/src/ops/base.rs
+++ b/burn-tch/src/ops/base.rs
@@ -15,7 +15,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
     ) -> TchTensor<E, D2> {
         let shape_tch: TchShape<D2> = shape.into();
 
-        TchTensor::from_existing(tensor.tensor.reshape(&shape_tch.dims), tensor.storage)
+        TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
     }
 
     pub fn index<const D1: usize, const D2: usize>(
@@ -42,7 +42,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         let tensor_original = tensor.tensor.copy();
         let tch_shape = TchShape::from(tensor.shape());
 
-        let mut tensor = tensor_original.view_(&tch_shape.dims);
+        let mut tensor = tensor_original.view_(tch_shape.dims);
 
         for (i, index) in indexes.into_iter().enumerate().take(D2) {
             let start = index.start as i64;
diff --git a/burn-tch/src/ops/bool_tensor.rs b/burn-tch/src/ops/bool_tensor.rs
index fbe900b20..8a18efff3 100644
--- a/burn-tch/src/ops/bool_tensor.rs
+++ b/burn-tch/src/ops/bool_tensor.rs
@@ -19,13 +19,15 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
     }
 
     fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
-        let values: Vec<bool> = tensor.tensor.shallow_clone().into();
-        Data::new(values, tensor.shape())
+        let shape = Self::bool_shape(tensor);
+        let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
+        let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
+
+        Data::new(values.unwrap(), shape)
     }
 
     fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
-        let shape = tensor.shape();
-        Data::new(tensor.tensor.into(), shape)
+        Self::bool_to_data(&tensor)
     }
 
     fn bool_to_device(
@@ -51,7 +53,7 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
         device: &<TchBackend<E> as Backend>::Device,
     ) -> TchTensor<bool, D> {
         let tensor = tch::Tensor::empty(
-            &shape.dims.map(|a| a as i64),
+            shape.dims.map(|a| a as i64),
             (tch::Kind::Bool, (*device).into()),
         );
 
diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs
index 56403dbe1..9360179e3 100644
--- a/burn-tch/src/ops/int_tensor.rs
+++ b/burn-tch/src/ops/int_tensor.rs
@@ -16,13 +16,15 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
     }
 
     fn int_to_data<const D: usize>(tensor: &TchTensor<i64, D>) -> Data<i64, D> {
-        let values: Vec<i64> = tensor.tensor.shallow_clone().into();
-        Data::new(values, tensor.shape())
+        let shape = Self::int_shape(tensor);
+        let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
+        let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
+
+        Data::new(values.unwrap(), shape)
     }
 
     fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Data<i64, D> {
-        let shape = tensor.shape();
-        Data::new(tensor.tensor.into(), shape)
+        Self::int_to_data(&tensor)
     }
 
     fn int_to_device(
@@ -48,7 +50,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
         device: &<TchBackend<E> as Backend>::Device,
     ) -> TchTensor<i64, D> {
         let tensor = tch::Tensor::empty(
-            &shape.dims.map(|a| a as i64),
+            shape.dims.map(|a| a as i64),
             (tch::Kind::Int64, (*device).into()),
         );
 
@@ -201,7 +203,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
         let shape = TchShape::from(shape);
         let device: tch::Device = (*device).into();
 
-        TchTensor::new(tch::Tensor::zeros(&shape.dims, (tch::Kind::Int64, device)))
+        TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device)))
     }
 
     fn int_ones(
@@ -211,7 +213,7 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
         let shape = TchShape::from(shape);
         let device: tch::Device = (*device).into();
 
-        TchTensor::new(tch::Tensor::ones(&shape.dims, (tch::Kind::Int64, device)))
+        TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device)))
     }
 
     fn int_sum<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
diff --git a/burn-tch/src/ops/module.rs b/burn-tch/src/ops/module.rs
index 542fd6146..c194b2554 100644
--- a/burn-tch/src/ops/module.rs
+++ b/burn-tch/src/ops/module.rs
@@ -38,9 +38,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
             &x.tensor,
             &weight.tensor,
             bias.map(|t| t.tensor),
-            &options.stride.map(|i| i as i64),
-            &options.padding.map(|i| i as i64),
-            &options.dilation.map(|i| i as i64),
+            options.stride.map(|i| i as i64),
+            options.padding.map(|i| i as i64),
+            options.dilation.map(|i| i as i64),
             options.groups as i64,
         );
 
@@ -57,9 +57,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
             &x.tensor,
             &weight.tensor,
             bias.map(|t| t.tensor),
-            &options.stride.map(|i| i as i64),
-            &options.padding.map(|i| i as i64),
-            &options.dilation.map(|i| i as i64),
+            options.stride.map(|i| i as i64),
+            options.padding.map(|i| i as i64),
+            options.dilation.map(|i| i as i64),
             options.groups as i64,
         );
 
@@ -76,11 +76,11 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
             &x.tensor,
             &weight.tensor,
             bias.map(|t| t.tensor),
-            &options.stride.map(|i| i as i64),
-            &options.padding.map(|i| i as i64),
-            &options.padding_out.map(|i| i as i64),
+            options.stride.map(|i| i as i64),
+            options.padding.map(|i| i as i64),
+            options.padding_out.map(|i| i as i64),
             options.groups as i64,
-            &options.dilation.map(|i| i as i64),
+            options.dilation.map(|i| i as i64),
         );
 
         TchTensor::new(tensor)
@@ -96,11 +96,11 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
             &x.tensor,
             &weight.tensor,
             bias.map(|t| t.tensor),
-            &options.stride.map(|i| i as i64),
-            &options.padding.map(|i| i as i64),
-            &options.padding_out.map(|i| i as i64),
+            options.stride.map(|i| i as i64),
+            options.padding.map(|i| i as i64),
+            options.padding_out.map(|i| i as i64),
             options.groups as i64,
-            &options.dilation.map(|i| i as i64),
+            options.dilation.map(|i| i as i64),
         );
 
         TchTensor::new(tensor)
@@ -114,9 +114,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
     ) -> TchTensor<E, 4> {
         let tensor = tch::Tensor::avg_pool2d(
             &x.tensor,
-            &[kernel_size[0] as i64, kernel_size[1] as i64],
-            &[stride[0] as i64, stride[1] as i64],
-            &[padding[0] as i64, padding[1] as i64],
+            [kernel_size[0] as i64, kernel_size[1] as i64],
+            [stride[0] as i64, stride[1] as i64],
+            [padding[0] as i64, padding[1] as i64],
             false,
             true,
             None,
@@ -135,9 +135,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
         let tensor = tch::Tensor::avg_pool2d_backward(
             &x.tensor,
             &grad.tensor,
-            &[kernel_size[0] as i64, kernel_size[1] as i64],
-            &[stride[0] as i64, stride[1] as i64],
-            &[padding[0] as i64, padding[1] as i64],
+            [kernel_size[0] as i64, kernel_size[1] as i64],
+            [stride[0] as i64, stride[1] as i64],
+            [padding[0] as i64, padding[1] as i64],
             false,
             true,
             None,
@@ -154,10 +154,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
     ) -> TchTensor<E, 4> {
         let tensor = tch::Tensor::max_pool2d(
             &x.tensor,
-            &[kernel_size[0] as i64, kernel_size[1] as i64],
-            &[stride[0] as i64, stride[1] as i64],
-            &[padding[0] as i64, padding[1] as i64],
-            &[1, 1],
+            [kernel_size[0] as i64, kernel_size[1] as i64],
+            [stride[0] as i64, stride[1] as i64],
+            [padding[0] as i64, padding[1] as i64],
+            [1, 1],
             false,
         );
 
@@ -172,10 +172,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
    ) -> MaxPool2dWithIndexes<TchBackend<E>> {
         let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
             &x.tensor,
-            &[kernel_size[0] as i64, kernel_size[1] as i64],
-            &[stride[0] as i64, stride[1] as i64],
-            &[padding[0] as i64, padding[1] as i64],
-            &[1, 1],
+            [kernel_size[0] as i64, kernel_size[1] as i64],
+            [stride[0] as i64, stride[1] as i64],
+            [padding[0] as i64, padding[1] as i64],
+            [1, 1],
             false,
         );
 
@@ -193,10 +193,10 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
         let grad = tch::Tensor::max_pool2d_with_indices_backward(
             &x.tensor,
             &output_grad.tensor,
-            &[kernel_size[0] as i64, kernel_size[1] as i64],
-            &[stride[0] as i64, stride[1] as i64],
-            &[padding[0] as i64, padding[1] as i64],
-            &[1, 1],
+            [kernel_size[0] as i64, kernel_size[1] as i64],
+            [stride[0] as i64, stride[1] as i64],
+            [padding[0] as i64, padding[1] as i64],
+            [1, 1],
             false,
             &indexes.tensor,
         );
diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs
index 758c9bf46..fb4775a8a 100644
--- a/burn-tch/src/ops/tensor.rs
+++ b/burn-tch/src/ops/tensor.rs
@@ -52,14 +52,14 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         let shape = TchShape::from(shape);
         let device: tch::Device = (*device).into();
 
-        TchTensor::new(tch::Tensor::zeros(&shape.dims, (E::KIND, device)))
+        TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device)))
     }
 
     fn ones<const D: usize>(shape: Shape<D>, device: &TchDevice) -> TchTensor<E, D> {
         let shape = TchShape::from(shape);
         let device: tch::Device = (*device).into();
 
-        TchTensor::new(tch::Tensor::ones(&shape.dims, (E::KIND, device)))
+        TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device)))
     }
 
     fn shape<const D: usize>(tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>) -> Shape<D> {
@@ -69,15 +69,17 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
     fn to_data<const D: usize>(
         tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>,
     ) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
-        let values: Vec<E> = tensor.tensor.shallow_clone().into();
-        Data::new(values, tensor.shape())
+        let shape = Self::shape(tensor);
+        let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
+        let values: Result<Vec<E>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
+
+        Data::new(values.unwrap(), shape)
     }
 
     fn into_data<const D: usize>(
         tensor: <TchBackend<E> as Backend>::TensorPrimitive<D>,
     ) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
-        let shape = tensor.shape();
-        Data::new(tensor.tensor.into(), shape)
+        Self::to_data(&tensor)
     }
 
     fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
@@ -92,7 +94,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         shape: Shape<D>,
         device: &<TchBackend<E> as Backend>::Device,
     ) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
-        let tensor = tch::Tensor::empty(&shape.dims.map(|a| a as i64), (E::KIND, (*device).into()));
+        let tensor = tch::Tensor::empty(shape.dims.map(|a| a as i64), (E::KIND, (*device).into()));
 
         TchTensor::new(tensor)
     }
diff --git a/burn-tch/src/tensor.rs b/burn-tch/src/tensor.rs
index 242457814..34435d5fd 100644
--- a/burn-tch/src/tensor.rs
+++ b/burn-tch/src/tensor.rs
@@ -168,7 +168,7 @@ impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
     pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
         let tensor = tch::Tensor::of_slice(data.value.as_slice()).to(device);
         let shape_tch = TchShape::from(data.shape);
-        let tensor = tensor.reshape(&shape_tch.dims).to_kind(E::KIND);
+        let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);
 
         Self::new(tensor)
     }
@@ -192,7 +192,7 @@ mod utils {
 impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
     pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
         let shape_tch = TchShape::from(shape);
-        let tensor = tch::Tensor::empty(&shape_tch.dims, (E::KIND, device.into()));
+        let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
 
         Self::new(tensor)
     }
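
The hunks above track two tch 0.12 API changes. First, shape-like arguments (`reshape`, `view_`, `zeros`, `ones`, `empty`, and the convolution/pooling option lists) now accept `[i64; N]` arrays by value instead of `&[i64]` slices. Second, converting a `Tensor` into a `Vec` became fallible: the old `From`-based `.into()` is replaced by `.try_into()`, which returns `Result<Vec<_>, tch::TchError>` and appears to expect a one-dimensional tensor, which is why each `*_to_data` now flattens to `Shape::new([shape.num_elements()])` first. A minimal standalone sketch of both changes, assuming `tch = "0.12"` with default features; the concrete values and the `reshape([-1])` flatten are illustrative only:

    use tch::{Device, Kind, Tensor};

    fn main() -> Result<(), tch::TchError> {
        // tch 0.12: shape arguments are taken by value ([i64; N]),
        // where tch 0.11 required &[i64] slices.
        let t = Tensor::zeros([2i64, 3], (Kind::Float, Device::Cpu));

        // Tensor -> Vec conversion is now TryFrom, returning
        // Result<Vec<_>, TchError>; flatten to one dimension first,
        // mirroring the Shape::new([shape.num_elements()]) reshape
        // in the patch (reshape([-1]) infers the flattened length).
        let flat = t.reshape([-1i64]);
        let values: Vec<f32> = flat.try_into()?;
        assert_eq!(values.len(), 6);

        Ok(())
    }

Note the design choice in the patch itself: every conversion is funneled through `*_to_data` (with `*_into_data` delegating to it), so the flatten-then-`try_into` sequence lives in one place per element type, and the `unwrap` keeps the public `Data` API infallible as before.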