Fix tch view data corruption (#1434)

This commit is contained in:
Nathaniel Simard 2024-03-08 09:55:47 -05:00 committed by GitHub
parent 61c0474172
commit 2de270fe0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 97 additions and 12 deletions

View File

@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc};
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;
/// A reference to a tensor storage.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
/// When a tensor is a partial view of another tensor.
View {
/// Storage reference for the whole buffer.
buffer_ref: StorageRef,
/// Storage reference for the partial buffer.
view_ref: StorageRef,
},
/// When a tensor use all of its buffer.
Owned {
/// Storage reference for the whole buffer.
buffer_ref: StorageRef,
},
}
impl Storage {
/// Check if the storage can be used inplace.
pub fn can_mut(&self) -> bool {
match self {
Storage::View {
buffer_ref: start_ref,
view_ref,
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
Storage::Owned {
buffer_ref: start_ref,
} => Arc::strong_count(start_ref) == 1,
}
}
/// Get the whole buffer reference.
pub fn buffer_ref(&self) -> &StorageRef {
match self {
Storage::View {
buffer_ref: start_ref,
view_ref: _,
} => start_ref,
Storage::Owned {
buffer_ref: start_ref,
} => start_ref,
}
}
}
/// A tensor that uses the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
/// Handle to the tensor. Call methods on this field.
pub tensor: tch::Tensor,
/// The tensor's storage
pub storage: StorageRef,
pub storage: Storage,
phantom: PhantomData<E>,
}
@ -27,12 +72,14 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
/// instead.
pub fn new(tensor: tch::Tensor) -> Self {
#[allow(clippy::arc_with_non_send_sync)]
let data = Arc::new(tensor.data_ptr());
let storage = Storage::Owned {
buffer_ref: Arc::new(tensor.data_ptr()),
};
Self {
tensor,
phantom: PhantomData,
storage: data,
storage,
}
}
@ -40,13 +87,34 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
///
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
/// tracking how much tensors point to the same memory space.
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage_child = tensor.data_ptr();
let mut is_a_new_tensor = true;
#[allow(clippy::arc_with_non_send_sync)]
let storage = match storage_child == *storage_parent {
true => storage_parent.clone(),
false => Arc::new(storage_child),
match &storage_parent {
Storage::View {
buffer_ref: start_ref,
view_ref,
} => {
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
is_a_new_tensor = false;
}
}
Storage::Owned {
buffer_ref: start_ref,
} => {
if storage_child == *start_ref.as_ref() {
is_a_new_tensor = false;
}
}
};
let storage = match is_a_new_tensor {
true => Storage::Owned {
#[allow(clippy::arc_with_non_send_sync)]
buffer_ref: Arc::new(storage_child),
},
false => storage_parent.clone(),
};
Self {
@ -57,10 +125,15 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
}
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
pub fn partial(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage = Storage::View {
buffer_ref: storage_parent.buffer_ref().clone(),
#[allow(clippy::arc_with_non_send_sync)]
view_ref: Arc::new(tensor.data_ptr()),
};
Self {
tensor,
storage: storage_parent,
storage,
phantom: PhantomData,
}
}
@ -96,7 +169,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
&mut self,
func: F,
) -> Option<TchTensor<EOut, D_OUT>> {
if Arc::strong_count(&self.storage) > 1 {
if !self.storage.can_mut() {
return None;
}
@ -113,7 +186,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
FOwn: Fn(tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
if Arc::strong_count(&self.storage) > 1 {
if !self.storage.can_mut() {
return TchTensor::from_existing(fref(&self.tensor), self.storage);
}

View File

@ -70,6 +70,18 @@ mod tests {
assert_eq!(reshaped.shape(), [4, 3].into());
}
#[test]
fn should_not_corrupt_after_slice() {
let zeros = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
zeros.clone().slice([1..2]).reshape([1]).exp();
// May lead to zeroes being equal to [0.0, 1.0]
assert_eq!(
zeros.to_data(),
Tensor::<TestBackend, 1>::zeros([2], &Default::default()).to_data()
);
}
#[test]
#[should_panic]
fn multiple_neg_ones() {