mirror of https://github.com/tracel-ai/burn.git
Fix tch view data corruption (#1434)
This commit is contained in:
parent
61c0474172
commit
2de270fe0e
|
@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc};
|
||||||
#[allow(clippy::arc_with_non_send_sync)]
|
#[allow(clippy::arc_with_non_send_sync)]
|
||||||
pub type StorageRef = Arc<*mut c_void>;
|
pub type StorageRef = Arc<*mut c_void>;
|
||||||
|
|
||||||
|
/// A reference to a tensor storage.
|
||||||
|
#[derive(PartialEq, Debug, Clone)]
|
||||||
|
pub enum Storage {
|
||||||
|
/// When a tensor is a partial view of another tensor.
|
||||||
|
View {
|
||||||
|
/// Storage reference for the whole buffer.
|
||||||
|
buffer_ref: StorageRef,
|
||||||
|
/// Storage reference for the partial buffer.
|
||||||
|
view_ref: StorageRef,
|
||||||
|
},
|
||||||
|
/// When a tensor use all of its buffer.
|
||||||
|
Owned {
|
||||||
|
/// Storage reference for the whole buffer.
|
||||||
|
buffer_ref: StorageRef,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
/// Check if the storage can be used inplace.
|
||||||
|
pub fn can_mut(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Storage::View {
|
||||||
|
buffer_ref: start_ref,
|
||||||
|
view_ref,
|
||||||
|
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
|
||||||
|
Storage::Owned {
|
||||||
|
buffer_ref: start_ref,
|
||||||
|
} => Arc::strong_count(start_ref) == 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the whole buffer reference.
|
||||||
|
pub fn buffer_ref(&self) -> &StorageRef {
|
||||||
|
match self {
|
||||||
|
Storage::View {
|
||||||
|
buffer_ref: start_ref,
|
||||||
|
view_ref: _,
|
||||||
|
} => start_ref,
|
||||||
|
Storage::Owned {
|
||||||
|
buffer_ref: start_ref,
|
||||||
|
} => start_ref,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// A tensor that uses the tch backend.
|
/// A tensor that uses the tch backend.
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
||||||
/// Handle to the tensor. Call methods on this field.
|
/// Handle to the tensor. Call methods on this field.
|
||||||
pub tensor: tch::Tensor,
|
pub tensor: tch::Tensor,
|
||||||
/// The tensor's storage
|
/// The tensor's storage
|
||||||
pub storage: StorageRef,
|
pub storage: Storage,
|
||||||
phantom: PhantomData<E>,
|
phantom: PhantomData<E>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,12 +72,14 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
||||||
/// instead.
|
/// instead.
|
||||||
pub fn new(tensor: tch::Tensor) -> Self {
|
pub fn new(tensor: tch::Tensor) -> Self {
|
||||||
#[allow(clippy::arc_with_non_send_sync)]
|
#[allow(clippy::arc_with_non_send_sync)]
|
||||||
let data = Arc::new(tensor.data_ptr());
|
let storage = Storage::Owned {
|
||||||
|
buffer_ref: Arc::new(tensor.data_ptr()),
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
tensor,
|
tensor,
|
||||||
phantom: PhantomData,
|
phantom: PhantomData,
|
||||||
storage: data,
|
storage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,13 +87,34 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
||||||
///
|
///
|
||||||
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
|
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
|
||||||
/// tracking how much tensors point to the same memory space.
|
/// tracking how much tensors point to the same memory space.
|
||||||
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
|
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
|
||||||
let storage_child = tensor.data_ptr();
|
let storage_child = tensor.data_ptr();
|
||||||
|
let mut is_a_new_tensor = true;
|
||||||
|
|
||||||
#[allow(clippy::arc_with_non_send_sync)]
|
match &storage_parent {
|
||||||
let storage = match storage_child == *storage_parent {
|
Storage::View {
|
||||||
true => storage_parent.clone(),
|
buffer_ref: start_ref,
|
||||||
false => Arc::new(storage_child),
|
view_ref,
|
||||||
|
} => {
|
||||||
|
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
|
||||||
|
is_a_new_tensor = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Storage::Owned {
|
||||||
|
buffer_ref: start_ref,
|
||||||
|
} => {
|
||||||
|
if storage_child == *start_ref.as_ref() {
|
||||||
|
is_a_new_tensor = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let storage = match is_a_new_tensor {
|
||||||
|
true => Storage::Owned {
|
||||||
|
#[allow(clippy::arc_with_non_send_sync)]
|
||||||
|
buffer_ref: Arc::new(storage_child),
|
||||||
|
},
|
||||||
|
false => storage_parent.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
@ -57,10 +125,15 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
|
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
|
||||||
pub fn partial(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
|
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
|
||||||
|
let storage = Storage::View {
|
||||||
|
buffer_ref: storage_parent.buffer_ref().clone(),
|
||||||
|
#[allow(clippy::arc_with_non_send_sync)]
|
||||||
|
view_ref: Arc::new(tensor.data_ptr()),
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
tensor,
|
tensor,
|
||||||
storage: storage_parent,
|
storage,
|
||||||
phantom: PhantomData,
|
phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -96,7 +169,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
||||||
&mut self,
|
&mut self,
|
||||||
func: F,
|
func: F,
|
||||||
) -> Option<TchTensor<EOut, D_OUT>> {
|
) -> Option<TchTensor<EOut, D_OUT>> {
|
||||||
if Arc::strong_count(&self.storage) > 1 {
|
if !self.storage.can_mut() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +186,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
||||||
FOwn: Fn(tch::Tensor) -> tch::Tensor,
|
FOwn: Fn(tch::Tensor) -> tch::Tensor,
|
||||||
FRef: Fn(&tch::Tensor) -> tch::Tensor,
|
FRef: Fn(&tch::Tensor) -> tch::Tensor,
|
||||||
{
|
{
|
||||||
if Arc::strong_count(&self.storage) > 1 {
|
if !self.storage.can_mut() {
|
||||||
return TchTensor::from_existing(fref(&self.tensor), self.storage);
|
return TchTensor::from_existing(fref(&self.tensor), self.storage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -70,6 +70,18 @@ mod tests {
|
||||||
assert_eq!(reshaped.shape(), [4, 3].into());
|
assert_eq!(reshaped.shape(), [4, 3].into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_not_corrupt_after_slice() {
|
||||||
|
let zeros = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
|
||||||
|
zeros.clone().slice([1..2]).reshape([1]).exp();
|
||||||
|
|
||||||
|
// May lead to zeroes being equal to [0.0, 1.0]
|
||||||
|
assert_eq!(
|
||||||
|
zeros.to_data(),
|
||||||
|
Tensor::<TestBackend, 1>::zeros([2], &Default::default()).to_data()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic]
|
#[should_panic]
|
||||||
fn multiple_neg_ones() {
|
fn multiple_neg_ones() {
|
||||||
|
|
Loading…
Reference in New Issue