Add the alloc_uninit function. (#1901)
* Add the alloc_uninit function. * Dummy metal fix. * Lazy initialization.
This commit is contained in:
parent
a00e24d752
commit
6708870e63
|
@ -127,6 +127,12 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
|||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
/// # Safety
|
||||
/// This function is unsafe as it doesn't initialize the underlying data store.
|
||||
/// The caller should ensure that the data is properly initialized as early as possible
|
||||
/// after this call.
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
||||
|
|
|
@ -2582,7 +2582,10 @@ impl BackendStorage for CpuStorage {
|
|||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
let mut kernel_c = unsafe {
|
||||
self.device()
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
|
@ -2590,7 +2593,7 @@ impl BackendStorage for CpuStorage {
|
|||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
@ -2681,7 +2684,10 @@ impl BackendStorage for CpuStorage {
|
|||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
let mut kernel_c = unsafe {
|
||||
self.device()
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
|
@ -2691,7 +2697,7 @@ impl BackendStorage for CpuStorage {
|
|||
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
|
||||
.transpose(1, 2)?
|
||||
.transpose(1, 3)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
@ -2919,6 +2925,53 @@ impl BackendDevice for CpuDevice {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::uninit_vec)]
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
// The code below is highly unsafe but hopefully not directly unsound as we only consider
|
||||
// types that are Copy, not Drop, and for which all bit patterns are proper values.
|
||||
// It's still pretty risky, see the following for more details:
|
||||
// https://github.com/rust-lang/rust-clippy/issues/4483
|
||||
let storage = match dtype {
|
||||
DType::U8 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::U8(v)
|
||||
}
|
||||
DType::U32 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::U32(v)
|
||||
}
|
||||
DType::I64 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::I64(v)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::BF16(v)
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::F16(v)
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::F32(v)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut v = Vec::with_capacity(elem_count);
|
||||
v.set_len(elem_count);
|
||||
CpuStorage::F64(v)
|
||||
}
|
||||
};
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
|
|
|
@ -384,6 +384,44 @@ impl BackendDevice for CudaDevice {
|
|||
self.const_impl(1., shape, dtype)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc::<u8>(elem_count).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc::<u32>(elem_count).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc::<i64>(elem_count).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc::<bf16>(elem_count).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc::<f16>(elem_count).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc::<f32>(elem_count).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc::<f64>(elem_count).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
|
@ -1916,7 +1954,10 @@ impl BackendStorage for CudaStorage {
|
|||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
let mut kernel_c = unsafe {
|
||||
self.device()
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
|
@ -1924,7 +1965,7 @@ impl BackendStorage for CudaStorage {
|
|||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
@ -1981,7 +2022,10 @@ impl BackendStorage for CudaStorage {
|
|||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
let mut kernel_c = unsafe {
|
||||
self.device()
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
|
@ -1991,7 +2035,7 @@ impl BackendStorage for CudaStorage {
|
|||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||
.transpose(1, 2)?
|
||||
.transpose(1, 3)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
@ -2128,7 +2172,7 @@ impl BackendStorage for CudaStorage {
|
|||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
|
@ -2143,7 +2187,7 @@ impl BackendStorage for CudaStorage {
|
|||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
|
|
|
@ -289,6 +289,23 @@ impl Device {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.alloc_uninit(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.alloc_uninit(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||
|
|
|
@ -210,6 +210,10 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
|
|
@ -222,6 +222,10 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
|
|
@ -1886,6 +1886,16 @@ impl BackendDevice for MetalDevice {
|
|||
self.device.registry_id() == rhs.device.registry_id()
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?;
|
||||
Ok(MetalStorage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let size = shape.elem_count() * dtype.size_in_bytes();
|
||||
let buffer = self.allocate_zeros(size)?;
|
||||
|
|
|
@ -1349,7 +1349,7 @@ impl Tensor {
|
|||
}
|
||||
.bt())?
|
||||
}
|
||||
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let offset = start * src.dims()[1..].iter().product::<usize>();
|
||||
|
@ -1999,7 +1999,7 @@ impl Tensor {
|
|||
Ok(self.clone())
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let op = BackpropOp::new1(self, Op::Copy);
|
||||
|
@ -2011,7 +2011,7 @@ impl Tensor {
|
|||
/// copied.
|
||||
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
||||
let shape = self.shape().clone();
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
||||
|
@ -2064,7 +2064,7 @@ impl Tensor {
|
|||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
} else {
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
|
|
|
@ -141,7 +141,7 @@ impl Tensor {
|
|||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage()
|
||||
|
@ -215,7 +215,7 @@ impl Tensor {
|
|||
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||
let mut dst_o = 0;
|
||||
for arg in args.iter() {
|
||||
let arg = arg.as_ref();
|
||||
|
|
Loading…
Reference in New Issue