Refactor element type to be decoupled from runtime (#1693)

This commit is contained in:
Guillaume Lagrange 2024-04-26 08:53:55 -04:00 committed by GitHub
parent 67ec06d5d8
commit ce2429eb10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 365 additions and 284 deletions

View File

@ -1,4 +1,4 @@
use crate::{codegen::Compiler, tensor::JitTensor, PrecisionBridge, Runtime};
use crate::{tensor::JitTensor, FloatElement, IntElement, PrecisionBridge, Runtime};
use burn_tensor::backend::Backend;
use rand::{rngs::StdRng, SeedableRng};
use std::{marker::PhantomData, sync::Mutex};
@ -7,16 +7,23 @@ pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
#[derive(new)]
pub struct JitBackend<R: Runtime> {
// Type parameters: `R` is the shader runtime, `F` becomes the backend's `FloatElem`
// and `I` its `IntElem` (see the `Backend` impl for this type).
pub struct JitBackend<R: Runtime, F: FloatElement, I: IntElement> {
// Zero-sized marker for the runtime `R`; no runtime value is stored here.
_runtime: PhantomData<R>,
// Zero-sized marker for the float element type `F`.
_float_elem: PhantomData<F>,
// Zero-sized marker for the int element type `I`.
_int_elem: PhantomData<I>,
}
impl<R: Runtime> Backend for JitBackend<R> {
impl<R, F, I> Backend for JitBackend<R, F, I>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
type Device = R::Device;
type FullPrecisionBridge = PrecisionBridge<R::FullPrecisionRuntime>;
type FloatElem = <R::Compiler as Compiler>::Float;
type IntElem = <R::Compiler as Compiler>::Int;
type FullPrecisionBridge = PrecisionBridge<R, f32, i32>;
type FloatElem = F;
type IntElem = I;
type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
@ -42,19 +49,19 @@ impl<R: Runtime> Backend for JitBackend<R> {
}
}
impl<R: Runtime> core::fmt::Debug for JitBackend<R> {
impl<R: Runtime, F: FloatElement, I: IntElement> core::fmt::Debug for JitBackend<R, F, I> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name()))
}
}
impl<R: Runtime> Clone for JitBackend<R> {
impl<R: Runtime, F: FloatElement, I: IntElement> Clone for JitBackend<R, F, I> {
fn clone(&self) -> Self {
Self::new()
}
}
impl<R: Runtime> Default for JitBackend<R> {
impl<R: Runtime, F: FloatElement, I: IntElement> Default for JitBackend<R, F, I> {
fn default() -> Self {
Self::new()
}

View File

@ -1,4 +1,6 @@
use crate::{kernel, ops::to_device, tensor::JitTensor, JitBackend, Runtime};
use crate::{
kernel, ops::to_device, tensor::JitTensor, FloatElement, IntElement, JitBackend, Runtime,
};
use burn_tensor::{
backend::BackendBridge,
ops::{FloatElem, FloatTensor},
@ -7,26 +9,31 @@ use core::marker::PhantomData;
/// Handle precision conversion for the jit backend.
#[derive(Debug)]
pub struct PrecisionBridge<R> {
// `F` and `I` are the float/int element types of the bridge's target backend:
// the `BackendBridge` impl sets `Target = JitBackend<R, F, I>` on the same runtime `R`.
pub struct PrecisionBridge<R, F: FloatElement, I: IntElement> {
// Zero-sized marker for the runtime `R`.
_runtime: PhantomData<R>,
// Zero-sized marker for the target float element type `F`.
_float_elem: PhantomData<F>,
// Zero-sized marker for the target int element type `I`.
_int_elem: PhantomData<I>,
}
impl<ROrigin, RTarget> BackendBridge<JitBackend<ROrigin>> for PrecisionBridge<RTarget>
impl<R, FOrigin, IOrigin, FTarget, ITarget> BackendBridge<JitBackend<R, FOrigin, IOrigin>>
for PrecisionBridge<R, FTarget, ITarget>
where
ROrigin: Runtime,
RTarget:
Runtime<Device = ROrigin::Device, Server = ROrigin::Server, Channel = ROrigin::Channel>,
R: Runtime,
FOrigin: FloatElement,
IOrigin: IntElement,
FTarget: FloatElement,
ITarget: IntElement,
{
type Target = JitBackend<RTarget>;
type Target = JitBackend<R, FTarget, ITarget>;
fn into_target<const D: usize>(
tensor: FloatTensor<JitBackend<ROrigin>, D>,
tensor: FloatTensor<JitBackend<R, FOrigin, IOrigin>, D>,
device: Option<burn_tensor::Device<Self::Target>>,
) -> FloatTensor<Self::Target, D> {
let tensor = kernel::cast::<
ROrigin,
FloatElem<JitBackend<ROrigin>>,
FloatElem<JitBackend<RTarget>>,
R,
FloatElem<JitBackend<R, FOrigin, IOrigin>>,
FloatElem<JitBackend<R, FTarget, ITarget>>,
D,
>(tensor);
@ -42,12 +49,12 @@ where
fn from_target<const D: usize>(
tensor: FloatTensor<Self::Target, D>,
device: Option<burn_tensor::Device<JitBackend<ROrigin>>>,
) -> FloatTensor<JitBackend<ROrigin>, D> {
device: Option<burn_tensor::Device<JitBackend<R, FOrigin, IOrigin>>>,
) -> FloatTensor<JitBackend<R, FOrigin, IOrigin>, D> {
let tensor = kernel::cast::<
RTarget,
FloatElem<JitBackend<RTarget>>,
FloatElem<JitBackend<ROrigin>>,
R,
FloatElem<JitBackend<R, FTarget, ITarget>>,
FloatElem<JitBackend<R, FOrigin, IOrigin>>,
D,
>(tensor);
// The line below does the backend type cast.

View File

@ -1,5 +1,4 @@
use super::dialect::gpu;
use crate::{FloatElement, IntElement};
use std::fmt::Display;
/// Compiles the [gpu representation](gpu::ComputeShader) into its own representation that can be
@ -7,16 +6,6 @@ use std::fmt::Display;
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
/// The representation for the compiled code.
type Representation: Display;
/// The float element type used for compilation.
type Float: FloatElement;
/// The int element type used for compilation.
type Int: IntElement;
/// The compiler that can be used to generate full precision shaders.
type FullPrecisionCompiler: Compiler<
Representation = Self::Representation,
Float = f32,
Int = i32,
>;
/// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation.
fn compile(shader: gpu::ComputeShader) -> Self::Representation;

View File

@ -390,13 +390,13 @@ impl From<bool> for Variable {
impl From<i32> for Variable {
fn from(value: i32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Int)
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32))
}
}
impl From<f32> for Variable {
fn from(value: f32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Float)
Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32))
}
}

View File

@ -17,11 +17,25 @@ pub enum Visibility {
ReadWrite,
}
/// Kind of floating-point element, carried as the payload of [`Elem::Float`].
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum FloatKind {
/// The kind reported by `f32` from `JitElement::gpu_elem()`.
F32,
// NOTE(review): presumably the 64-bit counterpart; no element type visible here maps to it — confirm.
F64,
}
/// Kind of signed integer element, carried as the payload of [`Elem::Int`].
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum IntKind {
/// The kind reported by `i32` from `JitElement::gpu_elem()`.
I32,
// NOTE(review): presumably the 64-bit counterpart; no element type visible here maps to it — confirm.
I64,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Elem {
Float,
Int,
Float(FloatKind),
Int(IntKind),
UInt,
Bool,
}
@ -35,8 +49,9 @@ impl From<Elem> for Item {
impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Float => f.write_str("float"),
Self::Int => f.write_str("int"),
// NOTE: we'll eventually want to differentiate between int/float types
Self::Float(_) => f.write_str("float"),
Self::Int(_) => f.write_str("int"),
Self::UInt => f.write_str("uint"),
Self::Bool => f.write_str("bool"),
}

View File

@ -312,8 +312,8 @@ fn create_scalar_handles<R: Runtime, E1: JitElement, E2: JitElement, E3: JitElem
) -> Vec<Handle<R::Server>> {
// It is crucial that scalars follow this order: float, int, uint
let element_priority = |elem: Elem| match elem {
Elem::Float => 0,
Elem::Int => 1,
Elem::Float(_) => 0,
Elem::Int(_) => 1,
Elem::UInt => 2,
Elem::Bool => panic!("Bool scalars are not supported"),
};

View File

@ -59,7 +59,7 @@ impl JitElement for i32 {
bytemuck::cast_slice(bytes)
}
fn gpu_elem() -> gpu::Elem {
gpu::Elem::Int
gpu::Elem::Int(gpu::IntKind::I32)
}
fn maximum_value() -> Self {
// Seems to cause problem for some GPU
@ -82,7 +82,7 @@ impl JitElement for f32 {
bytemuck::cast_slice(bytes)
}
fn gpu_elem() -> gpu::Elem {
gpu::Elem::Float
gpu::Elem::Float(gpu::FloatKind::F32)
}
fn maximum_value() -> Self {
f32::MAX

View File

@ -1,6 +1,7 @@
use super::{ElementWise, ElementWiseState};
use crate::{
element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, JitBackend, Runtime,
element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, FloatElement, IntElement,
JitBackend, Runtime,
};
use burn_compute::client::ComputeClient;
use burn_fusion::{client::MutexFusionClient, FusionBackend};
@ -25,8 +26,13 @@ pub enum JitOptimizationState {
ElementWise(ElementWiseState),
}
impl<R: Runtime> burn_fusion::Optimization<JitBackend<R>> for JitOptimization<R> {
fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend<R>>) {
impl<R, F, I> burn_fusion::Optimization<JitBackend<R, F, I>> for JitOptimization<R>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend<R, F, I>>) {
match self {
Self::ElementWise(op) => op.execute(context),
}
@ -53,7 +59,7 @@ impl<R: Runtime> burn_fusion::Optimization<JitBackend<R>> for JitOptimization<R>
}
}
impl<R: Runtime> ReprBackend for JitBackend<R> {
impl<R: Runtime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
type Handle = JitFusionHandle<R>;
fn float_tensor<const D: usize>(
@ -96,7 +102,7 @@ impl<R: Runtime> ReprBackend for JitBackend<R> {
}
}
impl<R: Runtime> FusionBackend for JitBackend<R> {
impl<R: Runtime, F: FloatElement, I: IntElement> FusionBackend for JitBackend<R, F, I> {
type OptimizationState = JitOptimizationState;
type Optimization = JitOptimization<R>;
type FusionClient = MutexFusionClient<Self>;
@ -104,7 +110,7 @@ impl<R: Runtime> FusionBackend for JitBackend<R> {
fn optimizations(
device: R::Device,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
vec![Box::new(ElementWiseBuilder::new(device))]
vec![Box::new(ElementWiseBuilder::<R, F, I>::new(device))]
}
}

View File

@ -1,3 +1,5 @@
use core::marker::PhantomData;
use super::{optimization::ElementWise, CompilationPhase};
use crate::{
codegen::dialect::gpu::{
@ -5,7 +7,7 @@ use crate::{
},
element::JitElement,
fusion::{tracing::TraceBuilder, JitOptimization},
JitBackend, Runtime,
FloatElement, IntElement, JitBackend, Runtime,
};
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
use burn_tensor::{
@ -19,15 +21,22 @@ use burn_tensor::{
};
/// Fused element wise operations that are normally memory bound.
pub(crate) struct ElementWiseBuilder<R: Runtime> {
// `F` and `I` name the float/int element types of the `JitBackend<R, F, I>` whose
// operations are registered with this builder.
pub(crate) struct ElementWiseBuilder<R: Runtime, F: FloatElement, I: IntElement> {
// Trace builder, created fresh with `TraceBuilder::new()` in `new`.
builder: TraceBuilder,
// Starts empty in `new`.
// NOTE(review): presumably the output shape shared by the fused operations — confirm.
current_output_shape: Vec<usize>,
// Starts as `Open`; once `Closed`, `register` returns immediately.
status: OptimizationStatus,
// Starts at 0 in `new`.
// NOTE(review): presumably the number of operations accepted so far — confirm.
num_added: usize,
// Runtime device handed to `new`.
device: R::Device,
// Zero-sized marker for the float element type `F`.
_float_elem: PhantomData<F>,
// Zero-sized marker for the int element type `I`.
_int_elem: PhantomData<I>,
}
impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<R> {
impl<R, F, I> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<R, F, I>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
fn register(&mut self, ops: &OperationDescription) {
if let OptimizationStatus::Closed = self.status {
return;
@ -35,31 +44,31 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<
match ops {
OperationDescription::BaseFloat(ops) => {
if !self.register_base::<FloatElem<JitBackend<R>>>(ops) {
if !self.register_base::<FloatElem<JitBackend<R, F, I>>>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
OperationDescription::BaseInt(ops) => {
if !self.register_base::<IntElem<JitBackend<R>>>(ops) {
if !self.register_base::<IntElem<JitBackend<R, F, I>>>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
OperationDescription::Float(ops) => {
if !self.register_float::<FloatElem<JitBackend<R>>>(ops) {
if !self.register_float::<FloatElem<JitBackend<R, F, I>>>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
OperationDescription::NumericFloat(ops) => {
if !self.register_numeric::<FloatElem<JitBackend<R>>, _>(ops) {
if !self.register_numeric::<FloatElem<JitBackend<R, F, I>>, _>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
OperationDescription::NumericInt(ops) => {
if !self.register_numeric::<IntElem<JitBackend<R>>, _>(ops) {
if !self.register_numeric::<IntElem<JitBackend<R, F, I>>, _>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
@ -110,14 +119,16 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<
}
}
impl<R: Runtime> ElementWiseBuilder<R> {
pub fn new(device: Device<JitBackend<R>>) -> Self {
impl<R: Runtime, F: FloatElement, I: IntElement> ElementWiseBuilder<R, F, I> {
pub fn new(device: Device<JitBackend<R, F, I>>) -> Self {
Self {
builder: TraceBuilder::new(),
num_added: 0,
current_output_shape: Vec::new(),
status: OptimizationStatus::Open,
device,
_float_elem: PhantomData,
_int_elem: PhantomData,
}
}

View File

@ -8,7 +8,7 @@ use crate::{
codegen::dialect::gpu::WorkgroupSize,
compute::JitAutotuneKey,
fusion::{kernel::FusionKernel, tracing::Trace},
JitBackend, Runtime,
FloatElement, IntElement, JitBackend, Runtime,
};
use burn_common::id::IdGenerator;
use burn_compute::client::ComputeClient;
@ -66,7 +66,10 @@ impl<R: Runtime> ElementWise<R, CompilationPhase> {
}
impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
pub(crate) fn execute(&mut self, context: &mut Context<'_, JitBackend<R>>) {
pub(crate) fn execute<F: FloatElement, I: IntElement>(
&mut self,
context: &mut Context<'_, JitBackend<R, F, I>>,
) {
let client = R::client(&self.device);
let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new(
@ -81,9 +84,9 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
}
}
fn run_kernel(
fn run_kernel<F: FloatElement, I: IntElement>(
&mut self,
context: &mut Context<'_, JitBackend<R>>,
context: &mut Context<'_, JitBackend<R, F, I>>,
client: ComputeClient<R::Server, R::Channel>,
fastest_set_index: usize,
) {
@ -106,9 +109,9 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
kernel.execute();
}
fn run_autotune(
fn run_autotune<F: FloatElement, I: IntElement>(
&mut self,
context: &mut Context<'_, JitBackend<R>>,
context: &mut Context<'_, JitBackend<R, F, I>>,
client: ComputeClient<R::Server, R::Channel>,
key: JitAutotuneKey,
) {
@ -152,9 +155,9 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
}
/// The first output is chosen when possible, otherwise the first input is chosen.
pub(crate) fn autotune_shape<'a>(
pub(crate) fn autotune_shape<'a, F: FloatElement, I: IntElement>(
&self,
context: &mut Context<'a, JitBackend<R>>,
context: &mut Context<'a, JitBackend<R, F, I>>,
) -> &'a [usize] {
let info = self.trace.running();

View File

@ -9,6 +9,8 @@ use crate::fusion::strides_dyn_rank;
use crate::fusion::JitFusionHandle;
use crate::gpu::ComputeShader;
use crate::kernel::GpuComputeShaderPhase;
use crate::FloatElement;
use crate::IntElement;
use crate::JitBackend;
use crate::Runtime;
use burn_compute::client::ComputeClient;
@ -106,14 +108,19 @@ impl<R: Runtime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
}
impl<R: Runtime> FusionKernel<R> {
pub fn create<K: FusionKernelFactory<R>>(
pub fn create<K, F, I>(
factory: &K,
running_info: &ExecutionInfo<'_>,
context: &mut Context<'_, JitBackend<R>>,
device: Device<JitBackend<R>>,
context: &mut Context<'_, JitBackend<R, F, I>>,
device: Device<JitBackend<R, F, I>>,
client: ComputeClient<R::Server, R::Channel>,
stateful: bool,
) -> ExecutableKernel<R> {
) -> ExecutableKernel<R>
where
K: FusionKernelFactory<R>,
F: FloatElement,
I: IntElement,
{
let (handles_input, inputs_description_updated, outputs_description_updated) =
process_inputs_outputs(
&running_info.inputs,
@ -266,16 +273,21 @@ fn register_info_tensor<R: Runtime>(
}
}
fn process_inputs_outputs<'a, R: Runtime>(
fn process_inputs_outputs<'a, R, F, I>(
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
context: &'a mut Context<'_, JitBackend<R>>,
context: &'a mut Context<'_, JitBackend<R, F, I>>,
stateful: bool,
) -> (
Vec<JitFusionHandle<R>>,
Vec<&'a TensorDescription>,
Vec<&'a TensorDescription>,
) {
)
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
let mut inputs_description_updated = Vec::with_capacity(inputs.len());
let mut outputs_description_updated = Vec::with_capacity(outputs.len());
let mut handles_input = Vec::new();

View File

@ -90,14 +90,14 @@ impl TraceBuilder {
/// Create a variable from an input [scalar](Element).
pub fn scalar<E: Element>(&mut self, _value: &E, elem_type: gpu::Elem) -> gpu::Variable {
match elem_type {
gpu::Elem::Float => {
gpu::Elem::Float(_) => {
let var = self
.scope
.read_scalar(self.scalars.num_float as u16, elem_type);
self.scalars.num_float += 1;
var
}
gpu::Elem::Int => {
gpu::Elem::Int(_) => {
let var = self
.scope
.read_scalar(self.scalars.num_int as u16, elem_type);

View File

@ -57,9 +57,10 @@ impl Trace {
})
.collect::<Vec<_>>();
// NOTE: we might want to pass a struct including all inputs/outputs metadata instead of 3 arrays
if self.scalars.num_float > 0 {
inputs.push(InputInfo::Scalar {
elem: gpu::Elem::Float,
elem: gpu::Elem::Float(gpu::FloatKind::F32),
size: self.scalars.num_float,
})
}
@ -73,7 +74,7 @@ impl Trace {
if self.scalars.num_int > 0 {
inputs.push(InputInfo::Scalar {
elem: gpu::Elem::Int,
elem: gpu::Elem::Int(gpu::IntKind::I32),
size: self.scalars.num_int,
})
}

View File

@ -6,7 +6,7 @@ use crate::{
OutputInfo, WorkgroupLaunch,
},
element::JitElement,
gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility},
gpu::{gpu, ComputeShader, Elem, IntKind, Scope, Variable, Visibility},
kernel::{self, GpuComputeShaderPhase},
ops::{
numeric::{empty_device, zeros_device},
@ -99,8 +99,8 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let groups = Variable::GlobalScalar(6, Elem::UInt);
let stride_0_i = scope.create_local(Elem::Int);
let stride_1_i = scope.create_local(Elem::Int);
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, stride_0_i = cast(conv_stride_0));
gpu!(scope, stride_1_i = cast(conv_stride_1));
@ -139,15 +139,15 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
gpu!(scope, ic_end = ic_start + ic_tmp);
let tmp_u = scope.create_local(Elem::UInt);
let tmp_i = scope.create_local(Elem::Int);
let zero_i = scope.zero(Elem::Int);
let one_i = scope.create_with_value(1, Elem::Int);
let tmp_i = scope.create_local(Elem::Int(IntKind::I32));
let zero_i = scope.zero(Elem::Int(IntKind::I32));
let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32));
let kms_u = scope.create_local(Elem::UInt);
let kms_0 = scope.create_local(Elem::Int);
let kms_1 = scope.create_local(Elem::Int);
let ih_start_tmp = scope.create_local(Elem::Int);
let iw_start_tmp = scope.create_local(Elem::Int);
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ih_start = scope.create_local(Elem::UInt);
let iw_start = scope.create_local(Elem::UInt);
let ih_end = scope.create_local(Elem::UInt);

View File

@ -1,6 +1,6 @@
use crate::codegen::dialect::gpu::{gpu, Elem, Scope, Variable};
use crate::codegen::Execution;
use crate::gpu::ComputeShader;
use crate::gpu::{ComputeShader, IntKind};
use crate::{
codegen::{
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
@ -80,7 +80,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for GatherEagerKernel<R, E
fn compile(&self) -> ComputeShader {
let mut scope = gpu::Scope::root();
let item_tensor = E::gpu_elem().into();
let item_indices: gpu::Item = gpu::Elem::Int.into();
let item_indices: gpu::Item = gpu::Elem::Int(IntKind::I32).into();
let tensor = gpu::Variable::GlobalInputArray(0, item_tensor);
let indices = scope.read_array(1, item_indices);
@ -103,7 +103,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for GatherEagerKernel<R, E
visibility: gpu::Visibility::Read,
};
let indices = InputInfo::Array {
item: gpu::Elem::Int.into(),
item: gpu::Elem::Int(IntKind::I32).into(),
visibility: gpu::Visibility::Read,
};
let out = OutputInfo::Array { item: item_tensor };

View File

@ -133,10 +133,10 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for ScatterEagerKernel<R,
fn compile(&self) -> ComputeShader {
let mut scope = gpu::Scope::root();
let item_value = E::gpu_elem().into();
let item_indices: gpu::Item = gpu::Elem::Int.into();
let item_indices: gpu::Item = gpu::Elem::Int(gpu::IntKind::I32).into();
let input_output = gpu::Variable::GlobalInputArray(0, item_value);
let indices = gpu::Variable::GlobalInputArray(1, Elem::Int.into());
let indices = gpu::Variable::GlobalInputArray(1, Elem::Int(gpu::IntKind::I32).into());
let value = gpu::Variable::GlobalInputArray(2, item_value);
scope.write_global_custom(input_output);

View File

@ -1,6 +1,6 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility},
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
@ -74,7 +74,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for SelectEagerKernel<R, E
fn compile(&self) -> ComputeShader {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let item_indices: Item = Elem::Int.into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let input = Variable::GlobalInputArray(0, item);
let indices = Variable::GlobalInputArray(1, item_indices);

View File

@ -1,6 +1,6 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Branch, Elem, Item, Scope, Variable, Visibility},
dialect::gpu::{gpu, Branch, Elem, IntKind, Item, Scope, Variable, Visibility},
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
WorkgroupLaunch,
},
@ -132,7 +132,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for SelectAssignEagerKerne
fn compile(&self) -> ComputeShader {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let item_indices: Item = Elem::Int.into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let tensor = Variable::GlobalInputArray(0, item);
let value = Variable::GlobalInputArray(1, item);

View File

@ -17,16 +17,18 @@ struct InterpolateBicubicEagerKernel<R, E> {
_elem: PhantomData<E>,
}
struct InterpolateBicubicShader {
struct InterpolateBicubicShader<E> {
input: Variable,
output: Variable,
_elem: PhantomData<E>,
}
impl InterpolateBicubicShader {
impl<E: JitElement> InterpolateBicubicShader<E> {
pub(crate) fn expand(self, scope: &mut Scope) {
let input = self.input;
let output = self.output;
let id = Variable::Id;
let elem = E::gpu_elem();
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
@ -83,20 +85,20 @@ impl InterpolateBicubicShader {
let input_height = scope.create_local(Elem::UInt);
let output_height = scope.create_local(Elem::UInt);
let output_height_float = scope.create_local(Elem::Float);
let output_height_float = scope.create_local(elem);
let input_width = scope.create_local(Elem::UInt);
let output_width = scope.create_local(Elem::UInt);
let output_width_float = scope.create_local(Elem::Float);
let output_width_float = scope.create_local(elem);
let frac = scope.create_local(Elem::Float);
let frac = scope.create_local(elem);
let numerator = scope.create_local(Elem::UInt);
let numerator_float = scope.create_local(Elem::Float);
let numerator_float = scope.create_local(elem);
let not_zero = scope.create_local(Elem::Bool);
let y_in_float = scope.create_local(Elem::Float);
let y_in_float = scope.create_local(elem);
let y_in = scope.create_local(Elem::UInt);
let yw = scope.create_local(Elem::Float);
let yw = scope.create_local(elem);
let y_tmp = scope.create_local(Elem::UInt);
gpu!(scope, input_height = input_shape_2 - 1u32);
@ -123,9 +125,9 @@ impl InterpolateBicubicShader {
gpu!(scope, y_tmp = y_in + 2u32);
let y3 = Self::min(scope, y_tmp, input_height);
let x_in_float = scope.create_local(Elem::Float);
let x_in_float = scope.create_local(elem);
let x_in = scope.create_local(Elem::UInt);
let xw = scope.create_local(Elem::Float);
let xw = scope.create_local(elem);
let x_tmp = scope.create_local(Elem::UInt);
gpu!(scope, input_width = input_shape_3 - 1u32);
@ -374,7 +376,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for InterpolateBicubicEage
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
InterpolateBicubicShader { input, output }.expand(&mut scope);
InterpolateBicubicShader {
input,
output,
_elem: PhantomData::<E>,
}
.expand(&mut scope);
scope.write_global_custom(output);

View File

@ -17,16 +17,18 @@ struct InterpolateNearestEagerKernel<R, E> {
_elem: PhantomData<E>,
}
struct InterpolateNearestShader {
struct InterpolateNearestShader<E> {
input: Variable,
output: Variable,
_elem: PhantomData<E>,
}
impl InterpolateNearestShader {
impl<E: JitElement> InterpolateNearestShader<E> {
pub(crate) fn expand(self, scope: &mut Scope) {
let input = self.input;
let output = self.output;
let id = Variable::Id;
let elem = E::gpu_elem();
let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
@ -81,11 +83,11 @@ impl InterpolateNearestShader {
gpu!(scope, w = id / output_stride_3);
gpu!(scope, w = w % output_shape_3);
let factor_float = scope.create_local(Elem::Float);
let numerator_float = scope.create_local(Elem::Float);
let denominator_float = scope.create_local(Elem::Float);
let x = scope.create_local(Elem::Float);
let y = scope.create_local(Elem::Float);
let factor_float = scope.create_local(elem);
let numerator_float = scope.create_local(elem);
let denominator_float = scope.create_local(elem);
let x = scope.create_local(elem);
let y = scope.create_local(elem);
let xu = scope.create_local(Elem::UInt);
let yu = scope.create_local(Elem::UInt);
@ -130,7 +132,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for InterpolateNearestEage
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
InterpolateNearestShader { input, output }.expand(&mut scope);
InterpolateNearestShader {
input,
output,
_elem: PhantomData::<E>,
}
.expand(&mut scope);
scope.write_global_custom(output);

View File

@ -17,12 +17,13 @@ struct InterpolateNearestBackwardEagerKernel<R, E> {
_elem: PhantomData<E>,
}
struct InterpolateNearestBackwardShader {
struct InterpolateNearestBackwardShader<E> {
out_grad: Variable,
output: Variable,
_elem: PhantomData<E>,
}
impl InterpolateNearestBackwardShader {
impl<E: JitElement> InterpolateNearestBackwardShader<E> {
fn expand(self, scope: &mut Scope) {
let grad = self.out_grad;
let output = self.output;
@ -134,8 +135,9 @@ impl InterpolateNearestBackwardShader {
output_size: Variable,
input_size: Variable,
) -> Variable {
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let elem = E::gpu_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
gpu!(scope, index = input_index * output_size);
@ -154,8 +156,9 @@ impl InterpolateNearestBackwardShader {
output_size: Variable,
input_size: Variable,
) -> Variable {
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let elem = E::gpu_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(Elem::Bool);
let end_index = scope.create_local(Elem::UInt);
@ -189,7 +192,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
let out_grad = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
InterpolateNearestBackwardShader { out_grad, output }.expand(&mut scope);
InterpolateNearestBackwardShader {
out_grad,
output,
_elem: PhantomData::<E>,
}
.expand(&mut scope);
scope.write_global_custom(output);

View File

@ -18,10 +18,11 @@ use std::marker::PhantomData;
use super::simple_launch_options;
#[derive(new, Debug)]
struct MatmulEagerKernel<R: Runtime> {
struct MatmulEagerKernel<R: Runtime, E: JitElement> {
workgroup_size_x: usize,
workgroup_size_y: usize,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
struct MatmulComputeShader {
@ -151,7 +152,7 @@ impl MatmulComputeShader {
}
}
impl<R: Runtime> GpuComputeShaderPhase for MatmulEagerKernel<R> {
impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for MatmulEagerKernel<R, E> {
fn compile(&self) -> ComputeShader {
assert_eq!(
self.workgroup_size_x, self.workgroup_size_y,
@ -159,9 +160,17 @@ impl<R: Runtime> GpuComputeShaderPhase for MatmulEagerKernel<R> {
);
let mut scope = gpu::Scope::root();
let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
let elem = E::gpu_elem();
assert!(
elem == gpu::Elem::Float(gpu::FloatKind::F32)
|| elem == gpu::Elem::Float(gpu::FloatKind::F64),
"Only float elements are supported."
);
let item = elem.into();
let lhs = gpu::Variable::GlobalInputArray(0, item);
let rhs = gpu::Variable::GlobalInputArray(1, item);
let out = gpu::Variable::GlobalOutputArray(0, item);
scope.write_global_custom(out);
@ -172,16 +181,14 @@ impl<R: Runtime> GpuComputeShaderPhase for MatmulEagerKernel<R> {
.expand(&mut scope);
let lhs = InputInfo::Array {
item: gpu::Elem::Float.into(),
item,
visibility: gpu::Visibility::Read,
};
let rhs = InputInfo::Array {
item: gpu::Elem::Float.into(),
item,
visibility: gpu::Visibility::Read,
};
let out = OutputInfo::Array {
item: gpu::Elem::Float.into(),
};
let out = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![lhs, rhs],
@ -236,7 +243,7 @@ pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
workgroup_size_y,
);
let kernel = MatmulEagerKernel::<R>::new(workgroup_size_x, workgroup_size_y);
let kernel = MatmulEagerKernel::<R, E>::new(workgroup_size_x, workgroup_size_y);
Execution::start(kernel, rhs.client)
.inputs(&[

View File

@ -26,18 +26,27 @@ struct MatmulTiling2d<E: JitElement> {
}
#[derive(new, Debug)]
struct MatmulTiling2dEagerKernel<R: Runtime> {
struct MatmulTiling2dEagerKernel<R: Runtime, E: JitElement> {
config: Tiling2dConfig,
bounds_check_required: bool,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
impl<R: Runtime> GpuComputeShaderPhase for MatmulTiling2dEagerKernel<R> {
impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for MatmulTiling2dEagerKernel<R, E> {
fn compile(&self) -> ComputeShader {
let mut scope = gpu::Scope::root();
let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
let elem = E::gpu_elem();
assert!(
elem == gpu::Elem::Float(gpu::FloatKind::F32)
|| elem == gpu::Elem::Float(gpu::FloatKind::F64),
"Only float elements are supported."
);
let item = elem.into();
let lhs = gpu::Variable::GlobalInputArray(0, item);
let rhs = gpu::Variable::GlobalInputArray(1, item);
let out = gpu::Variable::GlobalOutputArray(0, item);
scope.write_global_custom(out);
@ -49,16 +58,14 @@ impl<R: Runtime> GpuComputeShaderPhase for MatmulTiling2dEagerKernel<R> {
.expand(&mut scope);
let lhs = InputInfo::Array {
item: gpu::Elem::Float.into(),
item,
visibility: gpu::Visibility::Read,
};
let rhs = InputInfo::Array {
item: gpu::Elem::Float.into(),
item,
visibility: gpu::Visibility::Read,
};
let out = OutputInfo::Array {
item: gpu::Elem::Float.into(),
};
let out = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![lhs, rhs],
@ -94,7 +101,7 @@ pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(
) -> JitTensor<R, E, D> {
let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config);
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), bounds_check_required);
let kernel = MatmulTiling2dEagerKernel::<R, E>::new(config.clone(), bounds_check_required);
let client = lhs.client.clone();
let lhs = match lhs.batch_swapped_with_row_col() {
@ -126,7 +133,7 @@ pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement + Element, const D: usi
out: JitTensor<R, E, D>,
config: Tiling2dConfig,
) -> JitTensor<R, E, D> {
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), false);
let kernel = MatmulTiling2dEagerKernel::<R, E>::new(config.clone(), false);
let client = lhs.client.clone();
// A tensor may need to be padded, in which case it will implicitly become contiguous

View File

@ -18,12 +18,13 @@ struct AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
_elem: PhantomData<E>,
}
struct AdaptiveAvgPool2dBackwardComputeShader {
struct AdaptiveAvgPool2dBackwardComputeShader<E> {
grad: Variable,
output: Variable,
_elem: PhantomData<E>,
}
impl AdaptiveAvgPool2dBackwardComputeShader {
impl<E: JitElement> AdaptiveAvgPool2dBackwardComputeShader<E> {
fn expand(self, scope: &mut Scope) {
let grad = self.grad;
let output = self.output;
@ -158,8 +159,9 @@ impl AdaptiveAvgPool2dBackwardComputeShader {
output_size: Variable,
input_size: Variable,
) -> Variable {
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let elem = E::gpu_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
gpu!(scope, index = output_size_index * input_size);
@ -177,8 +179,9 @@ impl AdaptiveAvgPool2dBackwardComputeShader {
output_size: Variable,
input_size: Variable,
) -> Variable {
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let elem = E::gpu_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(Elem::Bool);
let end_index = scope.create_local(Elem::UInt);
@ -213,7 +216,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
scope.write_global_custom(output);
AdaptiveAvgPool2dBackwardComputeShader { grad, output }.expand(&mut scope);
AdaptiveAvgPool2dBackwardComputeShader {
grad,
output,
_elem: PhantomData::<E>,
}
.expand(&mut scope);
let grad = InputInfo::Array {
item,

View File

@ -137,8 +137,9 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
output_size: Variable,
input_size: Variable,
) -> Variable {
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let elem = E::gpu_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
gpu!(scope, index = output_size_index * input_size);
@ -156,8 +157,9 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
output_size: Variable,
input_size: Variable,
) -> Variable {
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let elem = E::gpu_elem();
let numerator_float = scope.create_local(elem);
let div = scope.create_local(elem);
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(Elem::Bool);
let end_index = scope.create_local(Elem::UInt);

View File

@ -1,6 +1,6 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
dialect::gpu::{gpu, Elem, IntKind, Scope, Variable, Visibility},
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
@ -223,17 +223,17 @@ impl AvgPool2dBackwardComputeShader {
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let signed_ih = scope.create_local(Elem::Int);
let signed_iw = scope.create_local(Elem::Int);
let signed_ih = scope.create_local(Elem::Int(IntKind::I32));
let signed_iw = scope.create_local(Elem::Int(IntKind::I32));
let signed_pool_stride_0 = scope.create_local(Elem::Int);
let signed_pool_stride_1 = scope.create_local(Elem::Int);
let signed_dilation_0 = scope.create_local(Elem::Int);
let signed_dilation_1 = scope.create_local(Elem::Int);
let signed_padding_0 = scope.create_local(Elem::Int);
let signed_padding_1 = scope.create_local(Elem::Int);
let signed_kernel_size_0 = scope.create_local(Elem::Int);
let signed_kernel_size_1 = scope.create_local(Elem::Int);
let signed_pool_stride_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_pool_stride_1 = scope.create_local(Elem::Int(IntKind::I32));
let signed_dilation_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_dilation_1 = scope.create_local(Elem::Int(IntKind::I32));
let signed_padding_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_padding_1 = scope.create_local(Elem::Int(IntKind::I32));
let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
@ -248,8 +248,8 @@ impl AvgPool2dBackwardComputeShader {
gpu!(scope, signed_ih = cast(ih));
gpu!(scope, signed_iw = cast(iw));
let kms_0 = scope.create_local(Elem::Int);
let kms_1 = scope.create_local(Elem::Int);
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
@ -257,8 +257,8 @@ impl AvgPool2dBackwardComputeShader {
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
let oh_start_tmp = scope.create_local(Elem::Int);
let ow_start_tmp = scope.create_local(Elem::Int);
let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
@ -277,8 +277,8 @@ impl AvgPool2dBackwardComputeShader {
gpu!(scope, oh_start = cast(oh_start_tmp));
gpu!(scope, ow_start = cast(ow_start_tmp));
let oh_end_tmp = scope.create_local(Elem::Int);
let ow_end_tmp = scope.create_local(Elem::Int);
let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));

View File

@ -1,6 +1,6 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility},
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
@ -94,7 +94,7 @@ impl MaxPool2dBackwardComputeShader {
gpu!(scope, index_current_tmp = iw * output_stride_3);
gpu!(scope, index_current += index_current_tmp);
let index_select = scope.create_local(Elem::Int);
let index_select = scope.create_local(Elem::Int(IntKind::I32));
let index_max = scope.create_local(Elem::UInt);
let is_max = scope.create_local(Elem::Bool);
@ -169,17 +169,17 @@ impl MaxPool2dBackwardComputeShader {
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let signed_ih = scope.create_local(Elem::Int);
let signed_iw = scope.create_local(Elem::Int);
let signed_ih = scope.create_local(Elem::Int(IntKind::I32));
let signed_iw = scope.create_local(Elem::Int(IntKind::I32));
let signed_pool_stride_0 = scope.create_local(Elem::Int);
let signed_pool_stride_1 = scope.create_local(Elem::Int);
let signed_dilation_0 = scope.create_local(Elem::Int);
let signed_dilation_1 = scope.create_local(Elem::Int);
let signed_padding_0 = scope.create_local(Elem::Int);
let signed_padding_1 = scope.create_local(Elem::Int);
let signed_kernel_size_0 = scope.create_local(Elem::Int);
let signed_kernel_size_1 = scope.create_local(Elem::Int);
let signed_pool_stride_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_pool_stride_1 = scope.create_local(Elem::Int(IntKind::I32));
let signed_dilation_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_dilation_1 = scope.create_local(Elem::Int(IntKind::I32));
let signed_padding_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_padding_1 = scope.create_local(Elem::Int(IntKind::I32));
let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32));
let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
@ -194,8 +194,8 @@ impl MaxPool2dBackwardComputeShader {
gpu!(scope, signed_ih = cast(ih));
gpu!(scope, signed_iw = cast(iw));
let kms_0 = scope.create_local(Elem::Int);
let kms_1 = scope.create_local(Elem::Int);
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
@ -203,8 +203,8 @@ impl MaxPool2dBackwardComputeShader {
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
let oh_start_tmp = scope.create_local(Elem::Int);
let ow_start_tmp = scope.create_local(Elem::Int);
let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
@ -223,8 +223,8 @@ impl MaxPool2dBackwardComputeShader {
gpu!(scope, oh_start = cast(oh_start_tmp));
gpu!(scope, ow_start = cast(ow_start_tmp));
let oh_end_tmp = scope.create_local(Elem::Int);
let ow_end_tmp = scope.create_local(Elem::Int);
let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));
@ -268,7 +268,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int));
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int(IntKind::I32)));
let grad = Variable::GlobalInputArray(1, item);
let output = Variable::GlobalOutputArray(0, item);
@ -283,7 +283,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
.expand(&mut scope);
let indices = InputInfo::Array {
item: Item::Scalar(Elem::Int),
item: Item::Scalar(Elem::Int(IntKind::I32)),
visibility: Visibility::Read,
};

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
gpu::{gpu, ComputeShader, Elem, Item, Scope, Variable, Visibility},
gpu::{gpu, ComputeShader, Elem, IntKind, Item, Scope, Variable, Visibility},
kernel::GpuComputeShaderPhase,
JitElement, Runtime,
};
@ -182,7 +182,10 @@ impl<P: PoolStrategy, R: Runtime, E: JitElement> GpuComputeShaderPhase
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let indices = if P::with_indices() {
Some(Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int)))
Some(Variable::GlobalOutputArray(
1,
Item::Scalar(Elem::Int(IntKind::I32)),
))
} else {
None
};
@ -213,7 +216,7 @@ impl<P: PoolStrategy, R: Runtime, E: JitElement> GpuComputeShaderPhase
vec![
output,
OutputInfo::Array {
item: Item::Scalar(Elem::Int),
item: Item::Scalar(Elem::Int(IntKind::I32)),
},
]
} else {

View File

@ -44,7 +44,7 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let float_random = scope.create_local(Elem::Float);
let float_random = scope.create_local(E::gpu_elem());
cast_uint_to_float(scope, int_random, float_random);
let bernoulli = scope.create_local(Elem::Bool);

View File

@ -33,10 +33,11 @@ impl<E: JitElement> Prng<E> for Normal<E> {
state_3: Variable,
output: Variable,
) {
let elem = E::gpu_elem();
let item = output.item();
let mean = args[0];
let std = args[1];
let two_pi = scope.create_with_value(2. * PI, Elem::Float);
let two_pi = scope.create_with_value(2. * PI, elem);
let t_neg = scope.create_with_value(-2.0, item);
let two: Variable = 2u32.into();
@ -55,7 +56,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let unit_0 = scope.create_local(Elem::Float);
let unit_0 = scope.create_local(elem);
cast_uint_to_float(scope, int_random, unit_0);
// Second random uniform integer
@ -68,7 +69,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let unit_1 = scope.create_local(Elem::Float);
let unit_1 = scope.create_local(elem);
cast_uint_to_float(scope, int_random, unit_1);
// Box-Muller transform

View File

@ -31,6 +31,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
state_3: Variable,
output: Variable,
) {
let elem = E::gpu_elem();
let item = output.item();
let lower_bound = args[0];
let upper_bound = args[1];
@ -50,12 +51,12 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
gpu!(scope, int_random = int_random ^ state_2);
gpu!(scope, int_random = int_random ^ state_3);
let float_random = scope.create_local(Elem::Float);
let float_scale = scope.create_local(Elem::Float);
let float_random = scope.create_local(elem);
let float_scale = scope.create_local(elem);
cast_uint_to_float(scope, int_random, float_random);
gpu!(scope, float_scale = cast(scale));
let uniform_float = scope.create_local(Elem::Float);
let uniform_float = scope.create_local(elem);
let uniform = scope.create_local(item);
gpu!(scope, uniform_float = float_random * float_scale);
gpu!(scope, uniform = cast(uniform_float));

View File

@ -1,4 +1,4 @@
use crate::{JitBackend, Runtime};
use crate::{FloatElement, IntElement, JitBackend, Runtime};
use burn_tensor::ops::ActivationOps;
impl<R: Runtime> ActivationOps<Self> for JitBackend<R> {}
// Empty body: no `ActivationOps` method is overridden for this backend, so every
// operation falls back to the implementation provided by the trait itself.
impl<R: Runtime, F: FloatElement, I: IntElement> ActivationOps<Self> for JitBackend<R, F, I> {}

View File

@ -1,4 +1,4 @@
use crate::{kernel, JitBackend, Runtime};
use crate::{kernel, FloatElement, IntElement, JitBackend, Runtime};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
use burn_tensor::Reader;
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
@ -6,7 +6,12 @@ use std::ops::Range;
use super::{expand, permute};
impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
impl<R, F, I> BoolTensorOps<Self> for JitBackend<R, F, I>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
super::empty(shape, device)
}

View File

@ -3,14 +3,19 @@ use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryO
use crate::kernel::matmul::{matmul, MatmulStrategy};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::kernel::{self, reduce};
use crate::Runtime;
use crate::{unary, JitBackend};
use crate::{FloatElement, IntElement, Runtime};
use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use burn_tensor::{ops::FloatTensorOps, Data, Distribution, Shape};
use burn_tensor::{ElementConversion, Reader};
use std::ops::Range;
impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
impl<R, F, I> FloatTensorOps<Self> for JitBackend<R, F, I>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
fn float_from_data<const D: usize>(
data: Data<FloatElem<Self>, D>,
device: &Device<Self>,

View File

@ -1,12 +1,17 @@
use super::{expand, numeric, permute};
use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::{kernel, unary, JitBackend, Runtime};
use crate::{kernel, unary, FloatElement, IntElement, JitBackend, Runtime};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use burn_tensor::{ops::IntTensorOps, Data, Distribution, ElementConversion, Reader, Shape};
use std::ops::Range;
impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
impl<R, F, I> IntTensorOps<Self> for JitBackend<R, F, I>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
super::empty(shape, device)
}

View File

@ -1,11 +1,16 @@
use crate::{kernel, JitBackend, Runtime};
use crate::{kernel, FloatElement, IntElement, JitBackend, Runtime};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
ModuleOps,
};
use burn_tensor::ops::{FloatTensor, IntTensor};
impl<R: Runtime> ModuleOps<Self> for JitBackend<R> {
impl<R, F, I> ModuleOps<Self> for JitBackend<R, F, I>
where
R: Runtime,
F: FloatElement,
I: IntElement,
{
fn conv2d(
x: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,

View File

@ -4,9 +4,6 @@ use crate::{
};
use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
/// Type alias to the runtime signed int element type.
pub type RuntimeInt<R> = <<R as Runtime>::Compiler as Compiler>::Int;
/// Runtime for the [just-in-time backend](crate::JitBackend).
pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
/// The compiler used to compile the inner representation into tokens.
@ -26,17 +23,6 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
+ Sync
+ Send;
/// A version of the runtime that supports full precision.
///
/// Note that the runtime should share all other runtime components.
/// This way, it's possible to share the same handles for both runtimes and reduce data copies to a minimum.
type FullPrecisionRuntime: Runtime<
Compiler = <Self::Compiler as Compiler>::FullPrecisionCompiler,
Device = Self::Device,
Server = Self::Server,
Channel = Self::Channel,
>;
/// Retrieve the compute client from the runtime device.
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;

View File

@ -84,7 +84,7 @@ macro_rules! testgen_jit {
use super::*;
use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test};
pub type TestBackend = JitBackend<TestRuntime>;
pub type TestBackend = JitBackend<TestRuntime, f32, i32>;
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
@ -106,7 +106,7 @@ macro_rules! testgen_jit_fusion {
use super::*;
use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor};
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime>>;
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime, f32, i32>>;
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;

View File

@ -1,13 +1,11 @@
use super::LocalArray;
use super::{shader::ComputeShader, Item, SharedMemory};
use crate::compiler::wgsl;
use crate::{FloatElement, IntElement};
use burn_jit::gpu;
use std::marker::PhantomData;
/// Wgsl Compiler.
#[derive(Clone)]
pub struct WgslCompiler<F: FloatElement, I: IntElement> {
#[derive(Clone, Default)]
pub struct WgslCompiler {
num_inputs: usize,
num_outputs: usize,
local_invocation_index: bool,
@ -21,43 +19,16 @@ pub struct WgslCompiler<F: FloatElement, I: IntElement> {
num_workgroups: bool,
shared_memories: Vec<SharedMemory>,
local_arrays: Vec<LocalArray>,
_float: PhantomData<F>,
_int: PhantomData<I>,
}
impl<F: FloatElement, I: IntElement> core::fmt::Debug for WgslCompiler<F, I> {
impl core::fmt::Debug for WgslCompiler {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("WgslCompiler")
}
}
/// Manual `Default` for the element-generic compiler.
///
/// Produces an empty starting state: both counters are zero, every boolean
/// field is `false`, and the shared-memory and local-array lists are empty.
impl<F: FloatElement, I: IntElement> Default for WgslCompiler<F, I> {
fn default() -> Self {
Self {
num_inputs: 0,
num_outputs: 0,
local_invocation_index: false,
local_invocation_id: false,
global_invocation_id: false,
workgroup_id: false,
rank: false,
id: false,
stride: false,
shape: false,
num_workgroups: false,
shared_memories: Vec::default(),
local_arrays: Vec::default(),
// Zero-sized markers: they tie `F` and `I` to the type without storing a value.
_float: PhantomData,
_int: PhantomData,
}
}
}
impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
impl burn_jit::Compiler for WgslCompiler {
type Representation = ComputeShader;
type Float = F;
type Int = I;
type FullPrecisionCompiler = WgslCompiler<f32, i32>;
fn compile(shader: gpu::ComputeShader) -> Self::Representation {
let mut compiler = Self::default();
@ -73,7 +44,7 @@ impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
}
}
impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
impl WgslCompiler {
fn compile_shader(&mut self, mut value: gpu::ComputeShader) -> wgsl::ComputeShader {
self.num_inputs = value.inputs.len();
self.num_outputs = value.outputs.len();
@ -128,8 +99,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
fn compile_elem(value: gpu::Elem) -> wgsl::Elem {
match value {
gpu::Elem::Float => F::wgpu_elem(),
gpu::Elem::Int => I::wgpu_elem(),
gpu::Elem::Float(f) => match f {
gpu::FloatKind::F32 => wgsl::Elem::F32,
gpu::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
},
gpu::Elem::Int(i) => match i {
gpu::IntKind::I32 => wgsl::Elem::I32,
gpu::IntKind::I64 => panic!("i64 is not a valid WgpuElement"),
},
gpu::Elem::UInt => wgsl::Elem::U32,
gpu::Elem::Bool => wgsl::Elem::Bool,
}

View File

@ -45,7 +45,7 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
burn_fusion::Fusion<JitBackend<WgpuRuntime<G, F, I>>>;
burn_fusion::Fusion<JitBackend<WgpuRuntime<G>, F, I>>;
#[cfg(not(feature = "fusion"))]
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
@ -64,13 +64,13 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
///
/// You can enable the `fusion` feature flag to add that functionality, which might improve
/// performance.
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G, F, I>>;
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G>, F, I>;
#[cfg(test)]
mod tests {
use super::*;
pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi, f32, i32>;
pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi>;
burn_jit::testgen_all!();
}

View File

@ -1,7 +1,7 @@
use crate::{
compiler::wgsl,
compute::{WgpuServer, WgpuStorage},
FloatElement, GraphicsApi, IntElement, WgpuDevice,
GraphicsApi, WgpuDevice,
};
use alloc::sync::Arc;
use burn_common::stub::RwLock;
@ -19,13 +19,10 @@ use wgpu::{AdapterInfo, DeviceDescriptor};
/// Runtime that uses the [wgpu] crate with the wgsl compiler.
///
/// The [graphics api](GraphicsApi), the [float element](FloatElement) and the
/// [int element](IntElement) types are passed as generic.
/// The [graphics api](GraphicsApi) type is passed as generic.
#[derive(Debug)]
pub struct WgpuRuntime<G: GraphicsApi, F: FloatElement, I: IntElement> {
pub struct WgpuRuntime<G: GraphicsApi> {
_g: PhantomData<G>,
_f: PhantomData<F>,
_i: PhantomData<I>,
}
/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime).
@ -34,9 +31,8 @@ static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>>
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Runtime for WgpuRuntime<G, F, I> {
type FullPrecisionRuntime = WgpuRuntime<G, f32, i32>;
type Compiler = wgsl::WgslCompiler<F, I>;
impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
type Compiler = wgsl::WgslCompiler;
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
type Channel = MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>;

View File

@ -71,7 +71,7 @@ fn autodiff<B: AutodiffBackend>(device: &B::Device) {
}
fn main() {
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi>, f32, i32>;
type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
let device = Default::default();
inference::<MyBackend>(&device);

View File

@ -15,7 +15,7 @@ use burn::{
};
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
for Autodiff<JitBackend<WgpuRuntime<G, F, I>>>
for Autodiff<JitBackend<WgpuRuntime<G>, F, I>>
{
}

View File

@ -37,7 +37,7 @@ impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
}
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G, F, I>> {
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,