mirror of https://github.com/tracel-ai/burn.git
Fix tch view data corruption (#1434)
This commit is contained in:
parent
61c0474172
commit
2de270fe0e
|
@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc};
|
|||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
pub type StorageRef = Arc<*mut c_void>;
|
||||
|
||||
/// A reference to a tensor storage.
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub enum Storage {
|
||||
/// When a tensor is a partial view of another tensor.
|
||||
View {
|
||||
/// Storage reference for the whole buffer.
|
||||
buffer_ref: StorageRef,
|
||||
/// Storage reference for the partial buffer.
|
||||
view_ref: StorageRef,
|
||||
},
|
||||
/// When a tensor use all of its buffer.
|
||||
Owned {
|
||||
/// Storage reference for the whole buffer.
|
||||
buffer_ref: StorageRef,
|
||||
},
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
/// Check if the storage can be used inplace.
|
||||
pub fn can_mut(&self) -> bool {
|
||||
match self {
|
||||
Storage::View {
|
||||
buffer_ref: start_ref,
|
||||
view_ref,
|
||||
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
|
||||
Storage::Owned {
|
||||
buffer_ref: start_ref,
|
||||
} => Arc::strong_count(start_ref) == 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the whole buffer reference.
|
||||
pub fn buffer_ref(&self) -> &StorageRef {
|
||||
match self {
|
||||
Storage::View {
|
||||
buffer_ref: start_ref,
|
||||
view_ref: _,
|
||||
} => start_ref,
|
||||
Storage::Owned {
|
||||
buffer_ref: start_ref,
|
||||
} => start_ref,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tensor that uses the tch backend.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
||||
/// Handle to the tensor. Call methods on this field.
|
||||
pub tensor: tch::Tensor,
|
||||
/// The tensor's storage
|
||||
pub storage: StorageRef,
|
||||
pub storage: Storage,
|
||||
phantom: PhantomData<E>,
|
||||
}
|
||||
|
||||
|
@ -27,12 +72,14 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
/// instead.
|
||||
pub fn new(tensor: tch::Tensor) -> Self {
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let data = Arc::new(tensor.data_ptr());
|
||||
let storage = Storage::Owned {
|
||||
buffer_ref: Arc::new(tensor.data_ptr()),
|
||||
};
|
||||
|
||||
Self {
|
||||
tensor,
|
||||
phantom: PhantomData,
|
||||
storage: data,
|
||||
storage,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,13 +87,34 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
///
|
||||
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
|
||||
/// tracking how much tensors point to the same memory space.
|
||||
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
|
||||
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
|
||||
let storage_child = tensor.data_ptr();
|
||||
let mut is_a_new_tensor = true;
|
||||
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let storage = match storage_child == *storage_parent {
|
||||
true => storage_parent.clone(),
|
||||
false => Arc::new(storage_child),
|
||||
match &storage_parent {
|
||||
Storage::View {
|
||||
buffer_ref: start_ref,
|
||||
view_ref,
|
||||
} => {
|
||||
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
|
||||
is_a_new_tensor = false;
|
||||
}
|
||||
}
|
||||
Storage::Owned {
|
||||
buffer_ref: start_ref,
|
||||
} => {
|
||||
if storage_child == *start_ref.as_ref() {
|
||||
is_a_new_tensor = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let storage = match is_a_new_tensor {
|
||||
true => Storage::Owned {
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
buffer_ref: Arc::new(storage_child),
|
||||
},
|
||||
false => storage_parent.clone(),
|
||||
};
|
||||
|
||||
Self {
|
||||
|
@ -57,10 +125,15 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
}
|
||||
|
||||
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
|
||||
pub fn partial(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
|
||||
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
|
||||
let storage = Storage::View {
|
||||
buffer_ref: storage_parent.buffer_ref().clone(),
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
view_ref: Arc::new(tensor.data_ptr()),
|
||||
};
|
||||
Self {
|
||||
tensor,
|
||||
storage: storage_parent,
|
||||
storage,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +169,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
&mut self,
|
||||
func: F,
|
||||
) -> Option<TchTensor<EOut, D_OUT>> {
|
||||
if Arc::strong_count(&self.storage) > 1 {
|
||||
if !self.storage.can_mut() {
|
||||
return None;
|
||||
}
|
||||
|
||||
|
@ -113,7 +186,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
|||
FOwn: Fn(tch::Tensor) -> tch::Tensor,
|
||||
FRef: Fn(&tch::Tensor) -> tch::Tensor,
|
||||
{
|
||||
if Arc::strong_count(&self.storage) > 1 {
|
||||
if !self.storage.can_mut() {
|
||||
return TchTensor::from_existing(fref(&self.tensor), self.storage);
|
||||
}
|
||||
|
||||
|
|
|
@ -70,6 +70,18 @@ mod tests {
|
|||
assert_eq!(reshaped.shape(), [4, 3].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_not_corrupt_after_slice() {
|
||||
let zeros = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
|
||||
zeros.clone().slice([1..2]).reshape([1]).exp();
|
||||
|
||||
// May lead to zeroes being equal to [0.0, 1.0]
|
||||
assert_eq!(
|
||||
zeros.to_data(),
|
||||
Tensor::<TestBackend, 1>::zeros([2], &Default::default()).to_data()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn multiple_neg_ones() {
|
||||
|
|
Loading…
Reference in New Issue