mirror of https://github.com/tracel-ai/burn.git
Add missing docs for burn-wgpu and burn-tch (#427)
This commit is contained in:
parent
825aaa9977
commit
f0b266c8a3
|
@ -19,9 +19,17 @@ use burn_tensor::backend::Backend;
|
||||||
/// let device_vulkan = TchDevice::Vulkan; // Vulkan
|
/// let device_vulkan = TchDevice::Vulkan; // Vulkan
|
||||||
/// ```
|
/// ```
|
||||||
pub enum TchDevice {
|
pub enum TchDevice {
|
||||||
|
/// CPU device.
|
||||||
Cpu,
|
Cpu,
|
||||||
|
|
||||||
|
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
|
||||||
|
/// all Cuda devices found on the system.
|
||||||
Cuda(usize),
|
Cuda(usize),
|
||||||
|
|
||||||
|
/// Metal Performance Shaders device.
|
||||||
Mps,
|
Mps,
|
||||||
|
|
||||||
|
/// Vulkan device.
|
||||||
Vulkan,
|
Vulkan,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,6 +61,7 @@ impl Default for TchDevice {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The Tch backend.
|
||||||
#[derive(Clone, Copy, Default, Debug)]
|
#[derive(Clone, Copy, Default, Debug)]
|
||||||
pub struct TchBackend<E> {
|
pub struct TchBackend<E> {
|
||||||
_e: E,
|
_e: E,
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
|
#![warn(missing_docs)]
|
||||||
|
|
||||||
|
//! Burn Tch Backend
|
||||||
|
|
||||||
mod backend;
|
mod backend;
|
||||||
mod element;
|
mod element;
|
||||||
mod ops;
|
mod ops;
|
||||||
|
|
|
@ -3,8 +3,10 @@ use burn_tensor::{ops::TensorOps, Data, Shape};
|
||||||
use libc::c_void;
|
use libc::c_void;
|
||||||
use std::{marker::PhantomData, sync::Arc};
|
use std::{marker::PhantomData, sync::Arc};
|
||||||
|
|
||||||
|
/// A reference to a tensor storage.
|
||||||
pub type StorageRef = Arc<*mut c_void>;
|
pub type StorageRef = Arc<*mut c_void>;
|
||||||
|
|
||||||
|
/// A tensor that uses the tch backend.
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
|
||||||
pub(crate) tensor: tch::Tensor,
|
pub(crate) tensor: tch::Tensor,
|
||||||
|
@ -150,7 +152,9 @@ impl<P: tch::kind::Element, const D: usize> Clone for TchTensor<P, D> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A shape that can be used by LibTorch.
|
||||||
pub struct TchShape<const D: usize> {
|
pub struct TchShape<const D: usize> {
|
||||||
|
/// The shape's dimensions.
|
||||||
pub dims: [i64; D],
|
pub dims: [i64; D],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,6 +169,16 @@ impl<const D: usize> From<Shape<D>> for TchShape<D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
|
impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
|
||||||
|
/// Creates a new tensor from a shape and a device.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `data` - The tensor's data.
|
||||||
|
/// * `device` - The device on which the tensor will be allocated.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A new tensor.
|
||||||
pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
|
pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
|
||||||
let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device);
|
let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device);
|
||||||
let shape_tch = TchShape::from(data.shape);
|
let shape_tch = TchShape::from(data.shape);
|
||||||
|
@ -190,6 +204,16 @@ mod utils {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
|
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
|
||||||
|
/// Creates an empty tensor from a shape and a device.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape` - The shape of the tensor.
|
||||||
|
/// * `device` - The device to create the tensor on.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A new empty tensor.
|
||||||
pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
|
pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
|
||||||
let shape_tch = TchShape::from(shape);
|
let shape_tch = TchShape::from(shape);
|
||||||
let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
|
let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
|
||||||
|
|
|
@ -10,6 +10,7 @@ use std::{marker::PhantomData, sync::Mutex};
|
||||||
|
|
||||||
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||||
|
|
||||||
|
/// Wgpu backend.
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Default, Clone)]
|
||||||
pub struct WGPUBackend<G: GraphicsApi, F: FloatElement, I: IntElement> {
|
pub struct WGPUBackend<G: GraphicsApi, F: FloatElement, I: IntElement> {
|
||||||
_g: PhantomData<G>,
|
_g: PhantomData<G>,
|
||||||
|
|
|
@ -12,9 +12,19 @@
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||||
pub enum WgpuDevice {
|
pub enum WgpuDevice {
|
||||||
|
/// Discrete GPU with the given index. The index is the index of the discrete GPU in the list
|
||||||
|
/// of all discrete GPUs found on the system.
|
||||||
DiscreteGpu(usize),
|
DiscreteGpu(usize),
|
||||||
|
|
||||||
|
/// Integrated GPU with the given index. The index is the index of the integrated GPU in the
|
||||||
|
/// list of all integrated GPUs found on the system.
|
||||||
IntegratedGpu(usize),
|
IntegratedGpu(usize),
|
||||||
|
|
||||||
|
/// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of
|
||||||
|
/// all virtual GPUs found on the system.
|
||||||
VirtualGpu(usize),
|
VirtualGpu(usize),
|
||||||
|
|
||||||
|
/// CPU.
|
||||||
Cpu,
|
Cpu,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,24 +8,31 @@
|
||||||
/// - [DirectX 12](Dx12)
|
/// - [DirectX 12](Dx12)
|
||||||
/// - [WebGpu](WebGpu)
|
/// - [WebGpu](WebGpu)
|
||||||
pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
|
pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
|
||||||
|
/// The wgpu backend.
|
||||||
fn backend() -> wgpu::Backend;
|
fn backend() -> wgpu::Backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Vulkan graphics API.
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Vulkan;
|
pub struct Vulkan;
|
||||||
|
|
||||||
|
/// Metal graphics API.
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Metal;
|
pub struct Metal;
|
||||||
|
|
||||||
|
/// OpenGL graphics API.
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct OpenGl;
|
pub struct OpenGl;
|
||||||
|
|
||||||
|
/// DirectX 11 graphics API.
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Dx11;
|
pub struct Dx11;
|
||||||
|
|
||||||
|
/// DirectX 12 graphics API.
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Dx12;
|
pub struct Dx12;
|
||||||
|
|
||||||
|
/// WebGpu graphics API.
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct WebGpu;
|
pub struct WebGpu;
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ pub trait DynamicKernel {
|
||||||
fn id(&self) -> String;
|
fn id(&self) -> String;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates kernel source code by replacing some information using templating.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! kernel_wgsl {
|
macro_rules! kernel_wgsl {
|
||||||
(
|
(
|
||||||
|
@ -30,7 +31,7 @@ macro_rules! kernel_wgsl {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate kernel source code by replacing some information using templating.
|
/// Generates kernel source code by replacing some information using templating.
|
||||||
pub struct KernelSettings<
|
pub struct KernelSettings<
|
||||||
K: StaticKernel,
|
K: StaticKernel,
|
||||||
E: WgpuElement,
|
E: WgpuElement,
|
||||||
|
|
|
@ -8,6 +8,7 @@ kernel_wgsl!(
|
||||||
"../template/binary_elemwise_inplace.wgsl"
|
"../template/binary_elemwise_inplace.wgsl"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/// Creates a binary elementwise kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! binary_elemwise {
|
macro_rules! binary_elemwise {
|
||||||
(
|
(
|
||||||
|
@ -30,6 +31,7 @@ macro_rules! binary_elemwise {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a binary elementwise inplace kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! binary_elemwise_inplace {
|
macro_rules! binary_elemwise_inplace {
|
||||||
(
|
(
|
||||||
|
|
|
@ -8,6 +8,7 @@ kernel_wgsl!(
|
||||||
"../template/comparison/binary_inplace.wgsl"
|
"../template/comparison/binary_inplace.wgsl"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/// Creates a comparison kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! comparison {
|
macro_rules! comparison {
|
||||||
(
|
(
|
||||||
|
@ -30,6 +31,7 @@ macro_rules! comparison {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a comparison inplace kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! comparison_inplace {
|
macro_rules! comparison_inplace {
|
||||||
(
|
(
|
||||||
|
|
|
@ -7,6 +7,7 @@ kernel_wgsl!(
|
||||||
"../template/comparison/elem_inplace.wgsl"
|
"../template/comparison/elem_inplace.wgsl"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/// Creates a comparison elementwise kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! comparison_elem {
|
macro_rules! comparison_elem {
|
||||||
(
|
(
|
||||||
|
@ -26,6 +27,7 @@ macro_rules! comparison_elem {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a comparison elementwise inplace kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! comparison_elem_inplace {
|
macro_rules! comparison_elem_inplace {
|
||||||
(
|
(
|
||||||
|
|
|
@ -4,6 +4,7 @@ use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuT
|
||||||
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
||||||
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
|
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
|
||||||
|
|
||||||
|
/// Creates a unary kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! unary {
|
macro_rules! unary {
|
||||||
(
|
(
|
||||||
|
@ -54,6 +55,7 @@ macro_rules! unary {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a unary inplace kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! unary_inplace {
|
macro_rules! unary_inplace {
|
||||||
(
|
(
|
||||||
|
|
|
@ -7,6 +7,7 @@ kernel_wgsl!(
|
||||||
"../template/unary_scalar_inplace.wgsl"
|
"../template/unary_scalar_inplace.wgsl"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/// Creates a unary scalar kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! unary_scalar {
|
macro_rules! unary_scalar {
|
||||||
(
|
(
|
||||||
|
@ -42,6 +43,7 @@ macro_rules! unary_scalar {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a unary scalar inplace kernel.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! unary_scalar_inplace {
|
macro_rules! unary_scalar_inplace {
|
||||||
(
|
(
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
|
#![warn(missing_docs)]
|
||||||
|
|
||||||
|
//! Burn WGPU Backend
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate derive_new;
|
extern crate derive_new;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue