Add missing docs for burn-wgpu and burn-tch (#427)

This commit is contained in:
Dilshod Tadjibaev 2023-06-23 08:31:37 -05:00 committed by GitHub
parent 825aaa9977
commit f0b266c8a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 71 additions and 1 deletions

View File

@ -19,9 +19,17 @@ use burn_tensor::backend::Backend;
/// let device_vulkan = TchDevice::Vulkan; // Vulkan
/// ```
pub enum TchDevice {
/// CPU device.
Cpu,
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
/// all Cuda devices found on the system.
Cuda(usize),
/// Metal Performance Shaders device.
Mps,
/// Vulkan device.
Vulkan,
}
@ -53,6 +61,7 @@ impl Default for TchDevice {
}
}
/// The Tch backend.
#[derive(Clone, Copy, Default, Debug)]
pub struct TchBackend<E> {
_e: E,

View File

@ -1,3 +1,7 @@
#![warn(missing_docs)]
//! Burn Tch Backend
mod backend;
mod element;
mod ops;

View File

@ -3,8 +3,10 @@ use burn_tensor::{ops::TensorOps, Data, Shape};
use libc::c_void;
use std::{marker::PhantomData, sync::Arc};
/// A reference to a tensor storage.
pub type StorageRef = Arc<*mut c_void>;
/// A tensor that uses the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
pub(crate) tensor: tch::Tensor,
@ -150,7 +152,9 @@ impl<P: tch::kind::Element, const D: usize> Clone for TchTensor<P, D> {
}
}
/// A shape that can be used by LibTorch.
pub struct TchShape<const D: usize> {
/// The shape's dimensions.
pub dims: [i64; D],
}
@ -165,6 +169,16 @@ impl<const D: usize> From<Shape<D>> for TchShape<D> {
}
impl<E: tch::kind::Element + Default, const D: usize> TchTensor<E, D> {
/// Creates a new tensor from a shape and a device.
///
/// # Arguments
///
/// * `data` - The tensor's data.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// A new tensor.
pub fn from_data(data: Data<E, D>, device: tch::Device) -> Self {
let tensor = tch::Tensor::from_slice(data.value.as_slice()).to(device);
let shape_tch = TchShape::from(data.shape);
@ -190,6 +204,16 @@ mod utils {
}
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
/// Creates an empty tensor from a shape and a device.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// A new empty tensor.
pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
let shape_tch = TchShape::from(shape);
let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));

View File

@ -10,6 +10,7 @@ use std::{marker::PhantomData, sync::Mutex};
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
/// Wgpu backend.
#[derive(Debug, Default, Clone)]
pub struct WGPUBackend<G: GraphicsApi, F: FloatElement, I: IntElement> {
_g: PhantomData<G>,

View File

@ -12,9 +12,19 @@
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum WgpuDevice {
/// Discrete GPU with the given index. The index is the index of the discrete GPU in the list
/// of all discrete GPUs found on the system.
DiscreteGpu(usize),
/// Integrated GPU with the given index. The index is the index of the integrated GPU in the
/// list of all integrated GPUs found on the system.
IntegratedGpu(usize),
/// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of
/// all virtual GPUs found on the system.
VirtualGpu(usize),
/// CPU.
Cpu,
}

View File

@ -8,24 +8,31 @@
/// - [DirectX 12](Dx12)
/// - [WebGpu](WebGpu)
pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
/// The wgpu backend.
fn backend() -> wgpu::Backend;
}
/// Vulkan graphics API.
#[derive(Default, Debug, Clone)]
pub struct Vulkan;
/// Metal graphics API.
#[derive(Default, Debug, Clone)]
pub struct Metal;
/// OpenGL graphics API.
#[derive(Default, Debug, Clone)]
pub struct OpenGl;
/// DirectX 11 graphics API.
#[derive(Default, Debug, Clone)]
pub struct Dx11;
/// DirectX 12 graphics API.
#[derive(Default, Debug, Clone)]
pub struct Dx12;
/// WebGpu graphics API.
#[derive(Default, Debug, Clone)]
pub struct WebGpu;

View File

@ -13,6 +13,7 @@ pub trait DynamicKernel {
fn id(&self) -> String;
}
/// Generates kernel source code by replacing some information using templating.
#[macro_export]
macro_rules! kernel_wgsl {
(
@ -30,7 +31,7 @@ macro_rules! kernel_wgsl {
};
}
/// Generate kernel source code by replacing some information using templating.
/// Generates kernel source code by replacing some information using templating.
pub struct KernelSettings<
K: StaticKernel,
E: WgpuElement,

View File

@ -8,6 +8,7 @@ kernel_wgsl!(
"../template/binary_elemwise_inplace.wgsl"
);
/// Creates a binary elementwise kernel.
#[macro_export]
macro_rules! binary_elemwise {
(
@ -30,6 +31,7 @@ macro_rules! binary_elemwise {
};
}
/// Creates a binary elementwise inplace kernel.
#[macro_export]
macro_rules! binary_elemwise_inplace {
(

View File

@ -8,6 +8,7 @@ kernel_wgsl!(
"../template/comparison/binary_inplace.wgsl"
);
/// Creates a comparison kernel.
#[macro_export]
macro_rules! comparison {
(
@ -30,6 +31,7 @@ macro_rules! comparison {
};
}
/// Creates a comparison inplace kernel.
#[macro_export]
macro_rules! comparison_inplace {
(

View File

@ -7,6 +7,7 @@ kernel_wgsl!(
"../template/comparison/elem_inplace.wgsl"
);
/// Creates a comparison elementwise kernel.
#[macro_export]
macro_rules! comparison_elem {
(
@ -26,6 +27,7 @@ macro_rules! comparison_elem {
};
}
/// Creates a comparison elementwise inplace kernel.
#[macro_export]
macro_rules! comparison_elem_inplace {
(

View File

@ -4,6 +4,7 @@ use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuT
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
/// Creates a unary kernel.
#[macro_export]
macro_rules! unary {
(
@ -54,6 +55,7 @@ macro_rules! unary {
};
}
/// Creates a unary inplace kernel.
#[macro_export]
macro_rules! unary_inplace {
(

View File

@ -7,6 +7,7 @@ kernel_wgsl!(
"../template/unary_scalar_inplace.wgsl"
);
/// Creates a unary scalar kernel.
#[macro_export]
macro_rules! unary_scalar {
(
@ -42,6 +43,7 @@ macro_rules! unary_scalar {
};
}
/// Creates a unary scalar inplace kernel.
#[macro_export]
macro_rules! unary_scalar_inplace {
(

View File

@ -1,3 +1,7 @@
#![warn(missing_docs)]
//! Burn WGPU Backend
#[macro_use]
extern crate derive_new;