mirror of https://github.com/tracel-ai/burn.git
Add missing docs for burn-wgpu and burn-tch (#427)
This commit is contained in:
parent
825aaa9977
commit
f0b266c8a3
|
@ -19,9 +19,17 @@ use burn_tensor::backend::Backend;
|
|||
/// let device_vulkan = TchDevice::Vulkan; // Vulkan
|
||||
/// ```
|
||||
pub enum TchDevice {
    /// The system CPU.
    Cpu,

    /// A CUDA GPU, identified by its position in the list of all CUDA
    /// devices detected on the system.
    Cuda(usize),

    /// An Apple Metal Performance Shaders (MPS) device.
    Mps,

    /// A Vulkan device.
    Vulkan,
}
|
||||
|
||||
|
@ -53,6 +61,7 @@ impl Default for TchDevice {
|
|||
}
|
||||
}
|
||||
|
||||
/// The Tch backend.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
pub struct TchBackend<E> {
|
||||
_e: E,
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
#![warn(missing_docs)]
|
||||
|
||||
//! Burn Tch Backend
|
||||
|
||||
mod backend;
|
||||
mod element;
|
||||
mod ops;
|
||||
|
|
|
@ -3,8 +3,10 @@ use burn_tensor::{ops::TensorOps, Data, Shape};
|
|||
use libc::c_void;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
/// Shared, reference-counted handle to a raw LibTorch tensor storage pointer.
///
/// NOTE(review): `Arc` here manages the lifetime of the pointer value itself,
/// not the storage it points to — presumably LibTorch owns the underlying
/// allocation. TODO confirm against the tch bindings.
pub type StorageRef = Arc<*mut c_void>;
|
||||
|
||||
/// A tensor that uses the tch backend.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
||||
pub(crate) tensor: tch::Tensor,
|
||||
|
@ -150,7 +152,9 @@ impl<P: tch::kind::Element, const D: usize> Clone for TchTensor<P, D> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Tensor shape expressed in the form LibTorch expects (`i64` extents).
pub struct TchShape<const D: usize> {
    /// Extent of each of the `D` dimensions.
    pub dims: [i64; D],
}
|
||||
|
||||
|
@ -165,6 +169,16 @@ impl<const D: usize> From<Shape<D>> for TchShape<D> {
|
|||
}
|
||||
|
||||
impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
|
||||
/// Creates a new tensor from data and a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The tensor's data.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor.
|
||||
pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
|
||||
let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device);
|
||||
let shape_tch = TchShape::from(data.shape);
|
||||
|
@ -190,6 +204,16 @@ mod utils {
|
|||
}
|
||||
|
||||
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
|
||||
/// Creates an empty tensor from a shape and a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new empty tensor.
|
||||
pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
|
||||
let shape_tch = TchShape::from(shape);
|
||||
let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
|
||||
|
|
|
@ -10,6 +10,7 @@ use std::{marker::PhantomData, sync::Mutex};
|
|||
|
||||
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||
|
||||
/// Wgpu backend.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct WGPUBackend<G: GraphicsApi, F: FloatElement, I: IntElement> {
|
||||
_g: PhantomData<G>,
|
||||
|
|
|
@ -12,9 +12,19 @@
|
|||
/// ```
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum WgpuDevice {
    /// A discrete GPU, identified by its position in the list of all
    /// discrete GPUs detected on the system.
    DiscreteGpu(usize),

    /// An integrated GPU, identified by its position in the list of all
    /// integrated GPUs detected on the system.
    IntegratedGpu(usize),

    /// A virtual GPU, identified by its position in the list of all
    /// virtual GPUs detected on the system.
    VirtualGpu(usize),

    /// The system CPU.
    Cpu,
}
|
||||
|
||||
|
|
|
@ -8,24 +8,31 @@
|
|||
/// - [DirectX 12](Dx12)
|
||||
/// - [WebGpu](WebGpu)
|
||||
pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
    /// The `wgpu::Backend` that this graphics API selects.
    fn backend() -> wgpu::Backend;
}
|
||||
|
||||
/// Marker type selecting the Vulkan graphics API.
#[derive(Default, Debug, Clone)]
pub struct Vulkan;

/// Marker type selecting the Metal graphics API.
#[derive(Default, Debug, Clone)]
pub struct Metal;

/// Marker type selecting the OpenGL graphics API.
#[derive(Default, Debug, Clone)]
pub struct OpenGl;

/// Marker type selecting the DirectX 11 graphics API.
#[derive(Default, Debug, Clone)]
pub struct Dx11;

/// Marker type selecting the DirectX 12 graphics API.
#[derive(Default, Debug, Clone)]
pub struct Dx12;

/// Marker type selecting the WebGpu graphics API.
#[derive(Default, Debug, Clone)]
pub struct WebGpu;
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ pub trait DynamicKernel {
|
|||
fn id(&self) -> String;
|
||||
}
|
||||
|
||||
/// Generates kernel source code by replacing some information using templating.
|
||||
#[macro_export]
|
||||
macro_rules! kernel_wgsl {
|
||||
(
|
||||
|
@ -30,7 +31,7 @@ macro_rules! kernel_wgsl {
|
|||
};
|
||||
}
|
||||
|
||||
/// Generate kernel source code by replacing some information using templating.
|
||||
/// Settings used when generating kernel source code from a template.
|
||||
pub struct KernelSettings<
|
||||
K: StaticKernel,
|
||||
E: WgpuElement,
|
||||
|
|
|
@ -8,6 +8,7 @@ kernel_wgsl!(
|
|||
"../template/binary_elemwise_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a binary elementwise kernel.
|
||||
#[macro_export]
|
||||
macro_rules! binary_elemwise {
|
||||
(
|
||||
|
@ -30,6 +31,7 @@ macro_rules! binary_elemwise {
|
|||
};
|
||||
}
|
||||
|
||||
/// Creates a binary elementwise inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! binary_elemwise_inplace {
|
||||
(
|
||||
|
|
|
@ -8,6 +8,7 @@ kernel_wgsl!(
|
|||
"../template/comparison/binary_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a comparison kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison {
|
||||
(
|
||||
|
@ -30,6 +31,7 @@ macro_rules! comparison {
|
|||
};
|
||||
}
|
||||
|
||||
/// Creates a comparison inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison_inplace {
|
||||
(
|
||||
|
|
|
@ -7,6 +7,7 @@ kernel_wgsl!(
|
|||
"../template/comparison/elem_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a comparison elementwise kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison_elem {
|
||||
(
|
||||
|
@ -26,6 +27,7 @@ macro_rules! comparison_elem {
|
|||
};
|
||||
}
|
||||
|
||||
/// Creates a comparison elementwise inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! comparison_elem_inplace {
|
||||
(
|
||||
|
|
|
@ -4,6 +4,7 @@ use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuT
|
|||
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
||||
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
|
||||
|
||||
/// Creates a unary kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary {
|
||||
(
|
||||
|
@ -54,6 +55,7 @@ macro_rules! unary {
|
|||
};
|
||||
}
|
||||
|
||||
/// Creates a unary inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary_inplace {
|
||||
(
|
||||
|
|
|
@ -7,6 +7,7 @@ kernel_wgsl!(
|
|||
"../template/unary_scalar_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a unary scalar kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary_scalar {
|
||||
(
|
||||
|
@ -42,6 +43,7 @@ macro_rules! unary_scalar {
|
|||
};
|
||||
}
|
||||
|
||||
/// Creates a unary scalar inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary_scalar_inplace {
|
||||
(
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
#![warn(missing_docs)]
|
||||
|
||||
//! Burn WGPU Backend
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
|
|
Loading…
Reference in New Issue