diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index 49bfc9f40..7bd9ee747 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc}; #[allow(clippy::arc_with_non_send_sync)] pub type StorageRef = Arc<*mut c_void>; +/// A reference to a tensor storage. +#[derive(PartialEq, Debug, Clone)] +pub enum Storage { + /// When a tensor is a partial view of another tensor. + View { + /// Storage reference for the whole buffer. + buffer_ref: StorageRef, + /// Storage reference for the partial buffer. + view_ref: StorageRef, + }, + /// When a tensor use all of its buffer. + Owned { + /// Storage reference for the whole buffer. + buffer_ref: StorageRef, + }, +} + +impl Storage { + /// Check if the storage can be used inplace. + pub fn can_mut(&self) -> bool { + match self { + Storage::View { + buffer_ref: start_ref, + view_ref, + } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1, + Storage::Owned { + buffer_ref: start_ref, + } => Arc::strong_count(start_ref) == 1, + } + } + + /// Get the whole buffer reference. + pub fn buffer_ref(&self) -> &StorageRef { + match self { + Storage::View { + buffer_ref: start_ref, + view_ref: _, + } => start_ref, + Storage::Owned { + buffer_ref: start_ref, + } => start_ref, + } + } +} + /// A tensor that uses the tch backend. #[derive(Debug, PartialEq)] pub struct TchTensor { /// Handle to the tensor. Call methods on this field. pub tensor: tch::Tensor, /// The tensor's storage - pub storage: StorageRef, + pub storage: Storage, phantom: PhantomData, } @@ -27,12 +72,14 @@ impl TchTensor { /// instead. pub fn new(tensor: tch::Tensor) -> Self { #[allow(clippy::arc_with_non_send_sync)] - let data = Arc::new(tensor.data_ptr()); + let storage = Storage::Owned { + buffer_ref: Arc::new(tensor.data_ptr()), + }; Self { tensor, phantom: PhantomData, - storage: data, + storage, } } @@ -40,13 +87,34 @@ impl TchTensor { /// /// If the child tensor shared the same storage as its parent, it will be cloned, effectively /// tracking how much tensors point to the same memory space. - pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self { + pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self { let storage_child = tensor.data_ptr(); + let mut is_a_new_tensor = true; - #[allow(clippy::arc_with_non_send_sync)] - let storage = match storage_child == *storage_parent { - true => storage_parent.clone(), - false => Arc::new(storage_child), + match &storage_parent { + Storage::View { + buffer_ref: start_ref, + view_ref, + } => { + if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() { + is_a_new_tensor = false; + } + } + Storage::Owned { + buffer_ref: start_ref, + } => { + if storage_child == *start_ref.as_ref() { + is_a_new_tensor = false; + } + } + }; + + let storage = match is_a_new_tensor { + true => Storage::Owned { + #[allow(clippy::arc_with_non_send_sync)] + buffer_ref: Arc::new(storage_child), + }, + false => storage_parent.clone(), }; Self { @@ -57,10 +125,15 @@ impl TchTensor { } /// Create a tensor that uses a part of its parent tensor such as slice and narrow. - pub fn partial(tensor: tch::Tensor, storage_parent: StorageRef) -> Self { + pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self { + let storage = Storage::View { + buffer_ref: storage_parent.buffer_ref().clone(), + #[allow(clippy::arc_with_non_send_sync)] + view_ref: Arc::new(tensor.data_ptr()), + }; Self { tensor, - storage: storage_parent, + storage, phantom: PhantomData, } } @@ -96,7 +169,7 @@ impl TchTensor { &mut self, func: F, ) -> Option> { - if Arc::strong_count(&self.storage) > 1 { + if !self.storage.can_mut() { return None; } @@ -113,7 +186,7 @@ impl TchTensor { FOwn: Fn(tch::Tensor) -> tch::Tensor, FRef: Fn(&tch::Tensor) -> tch::Tensor, { - if Arc::strong_count(&self.storage) > 1 { + if !self.storage.can_mut() { return TchTensor::from_existing(fref(&self.tensor), self.storage); } diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index cc786dc2a..0e1c06f74 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -70,6 +70,18 @@ mod tests { assert_eq!(reshaped.shape(), [4, 3].into()); } + #[test] + fn should_not_corrupt_after_slice() { + let zeros = Tensor::::zeros([2], &Default::default()); + zeros.clone().slice([1..2]).reshape([1]).exp(); + + // May lead to zeroes being equal to [0.0, 1.0] + assert_eq!( + zeros.to_data(), + Tensor::::zeros([2], &Default::default()).to_data() + ); + } + #[test] #[should_panic] fn multiple_neg_ones() {