mirror of https://github.com/tracel-ai/burn.git
Refactor element type to be decoupled from runtime (#1693)
This commit is contained in:
parent
67ec06d5d8
commit
ce2429eb10
|
@ -1,4 +1,4 @@
|
|||
use crate::{codegen::Compiler, tensor::JitTensor, PrecisionBridge, Runtime};
|
||||
use crate::{tensor::JitTensor, FloatElement, IntElement, PrecisionBridge, Runtime};
|
||||
use burn_tensor::backend::Backend;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use std::{marker::PhantomData, sync::Mutex};
|
||||
|
@ -7,16 +7,23 @@ pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
|||
|
||||
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
|
||||
#[derive(new)]
|
||||
pub struct JitBackend<R: Runtime> {
|
||||
pub struct JitBackend<R: Runtime, F: FloatElement, I: IntElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_float_elem: PhantomData<F>,
|
||||
_int_elem: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> Backend for JitBackend<R> {
|
||||
impl<R, F, I> Backend for JitBackend<R, F, I>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
type Device = R::Device;
|
||||
|
||||
type FullPrecisionBridge = PrecisionBridge<R::FullPrecisionRuntime>;
|
||||
type FloatElem = <R::Compiler as Compiler>::Float;
|
||||
type IntElem = <R::Compiler as Compiler>::Int;
|
||||
type FullPrecisionBridge = PrecisionBridge<R, f32, i32>;
|
||||
type FloatElem = F;
|
||||
type IntElem = I;
|
||||
|
||||
type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
|
||||
type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
|
||||
|
@ -42,19 +49,19 @@ impl<R: Runtime> Backend for JitBackend<R> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> core::fmt::Debug for JitBackend<R> {
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> core::fmt::Debug for JitBackend<R, F, I> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for JitBackend<R> {
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> Clone for JitBackend<R, F, I> {
|
||||
fn clone(&self) -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> Default for JitBackend<R> {
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> Default for JitBackend<R, F, I> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use crate::{kernel, ops::to_device, tensor::JitTensor, JitBackend, Runtime};
|
||||
use crate::{
|
||||
kernel, ops::to_device, tensor::JitTensor, FloatElement, IntElement, JitBackend, Runtime,
|
||||
};
|
||||
use burn_tensor::{
|
||||
backend::BackendBridge,
|
||||
ops::{FloatElem, FloatTensor},
|
||||
|
@ -7,26 +9,31 @@ use core::marker::PhantomData;
|
|||
|
||||
/// Handle precision conversion for the jit backend.
|
||||
#[derive(Debug)]
|
||||
pub struct PrecisionBridge<R> {
|
||||
pub struct PrecisionBridge<R, F: FloatElement, I: IntElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_float_elem: PhantomData<F>,
|
||||
_int_elem: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<ROrigin, RTarget> BackendBridge<JitBackend<ROrigin>> for PrecisionBridge<RTarget>
|
||||
impl<R, FOrigin, IOrigin, FTarget, ITarget> BackendBridge<JitBackend<R, FOrigin, IOrigin>>
|
||||
for PrecisionBridge<R, FTarget, ITarget>
|
||||
where
|
||||
ROrigin: Runtime,
|
||||
RTarget:
|
||||
Runtime<Device = ROrigin::Device, Server = ROrigin::Server, Channel = ROrigin::Channel>,
|
||||
R: Runtime,
|
||||
FOrigin: FloatElement,
|
||||
IOrigin: IntElement,
|
||||
FTarget: FloatElement,
|
||||
ITarget: IntElement,
|
||||
{
|
||||
type Target = JitBackend<RTarget>;
|
||||
type Target = JitBackend<R, FTarget, ITarget>;
|
||||
|
||||
fn into_target<const D: usize>(
|
||||
tensor: FloatTensor<JitBackend<ROrigin>, D>,
|
||||
tensor: FloatTensor<JitBackend<R, FOrigin, IOrigin>, D>,
|
||||
device: Option<burn_tensor::Device<Self::Target>>,
|
||||
) -> FloatTensor<Self::Target, D> {
|
||||
let tensor = kernel::cast::<
|
||||
ROrigin,
|
||||
FloatElem<JitBackend<ROrigin>>,
|
||||
FloatElem<JitBackend<RTarget>>,
|
||||
R,
|
||||
FloatElem<JitBackend<R, FOrigin, IOrigin>>,
|
||||
FloatElem<JitBackend<R, FTarget, ITarget>>,
|
||||
D,
|
||||
>(tensor);
|
||||
|
||||
|
@ -42,12 +49,12 @@ where
|
|||
|
||||
fn from_target<const D: usize>(
|
||||
tensor: FloatTensor<Self::Target, D>,
|
||||
device: Option<burn_tensor::Device<JitBackend<ROrigin>>>,
|
||||
) -> FloatTensor<JitBackend<ROrigin>, D> {
|
||||
device: Option<burn_tensor::Device<JitBackend<R, FOrigin, IOrigin>>>,
|
||||
) -> FloatTensor<JitBackend<R, FOrigin, IOrigin>, D> {
|
||||
let tensor = kernel::cast::<
|
||||
RTarget,
|
||||
FloatElem<JitBackend<RTarget>>,
|
||||
FloatElem<JitBackend<ROrigin>>,
|
||||
R,
|
||||
FloatElem<JitBackend<R, FTarget, ITarget>>,
|
||||
FloatElem<JitBackend<R, FOrigin, IOrigin>>,
|
||||
D,
|
||||
>(tensor);
|
||||
// The line below does the backend type cast.
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use super::dialect::gpu;
|
||||
use crate::{FloatElement, IntElement};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Compiles the [gpu representation](gpu::ComputeShader) into its own representation that can be
|
||||
|
@ -7,16 +6,6 @@ use std::fmt::Display;
|
|||
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
|
||||
/// The representation for the compiled code.
|
||||
type Representation: Display;
|
||||
/// The float element type used for compilation.
|
||||
type Float: FloatElement;
|
||||
/// The int element type used for compilation.
|
||||
type Int: IntElement;
|
||||
/// The compiler that can be used to generate full precision shaders.
|
||||
type FullPrecisionCompiler: Compiler<
|
||||
Representation = Self::Representation,
|
||||
Float = f32,
|
||||
Int = i32,
|
||||
>;
|
||||
|
||||
/// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation.
|
||||
fn compile(shader: gpu::ComputeShader) -> Self::Representation;
|
||||
|
|
|
@ -390,13 +390,13 @@ impl From<bool> for Variable {
|
|||
|
||||
impl From<i32> for Variable {
|
||||
fn from(value: i32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Int)
|
||||
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for Variable {
|
||||
fn from(value: f32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Float)
|
||||
Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,11 +17,25 @@ pub enum Visibility {
|
|||
ReadWrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum FloatKind {
|
||||
F32,
|
||||
F64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum IntKind {
|
||||
I32,
|
||||
I64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum Elem {
|
||||
Float,
|
||||
Int,
|
||||
Float(FloatKind),
|
||||
Int(IntKind),
|
||||
UInt,
|
||||
Bool,
|
||||
}
|
||||
|
@ -35,8 +49,9 @@ impl From<Elem> for Item {
|
|||
impl Display for Elem {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Float => f.write_str("float"),
|
||||
Self::Int => f.write_str("int"),
|
||||
// NOTE: we'll eventually want to differentiate between int/float types
|
||||
Self::Float(_) => f.write_str("float"),
|
||||
Self::Int(_) => f.write_str("int"),
|
||||
Self::UInt => f.write_str("uint"),
|
||||
Self::Bool => f.write_str("bool"),
|
||||
}
|
||||
|
|
|
@ -312,8 +312,8 @@ fn create_scalar_handles<R: Runtime, E1: JitElement, E2: JitElement, E3: JitElem
|
|||
) -> Vec<Handle<R::Server>> {
|
||||
// It is crucial that scalars follow this order: float, int, uint
|
||||
let element_priority = |elem: Elem| match elem {
|
||||
Elem::Float => 0,
|
||||
Elem::Int => 1,
|
||||
Elem::Float(_) => 0,
|
||||
Elem::Int(_) => 1,
|
||||
Elem::UInt => 2,
|
||||
Elem::Bool => panic!("Bool scalars are not supported"),
|
||||
};
|
||||
|
|
|
@ -59,7 +59,7 @@ impl JitElement for i32 {
|
|||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
fn gpu_elem() -> gpu::Elem {
|
||||
gpu::Elem::Int
|
||||
gpu::Elem::Int(gpu::IntKind::I32)
|
||||
}
|
||||
fn maximum_value() -> Self {
|
||||
// Seems to cause problem for some GPU
|
||||
|
@ -82,7 +82,7 @@ impl JitElement for f32 {
|
|||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
fn gpu_elem() -> gpu::Elem {
|
||||
gpu::Elem::Float
|
||||
gpu::Elem::Float(gpu::FloatKind::F32)
|
||||
}
|
||||
fn maximum_value() -> Self {
|
||||
f32::MAX
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::{ElementWise, ElementWiseState};
|
||||
use crate::{
|
||||
element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, JitBackend, Runtime,
|
||||
element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, FloatElement, IntElement,
|
||||
JitBackend, Runtime,
|
||||
};
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_fusion::{client::MutexFusionClient, FusionBackend};
|
||||
|
@ -25,8 +26,13 @@ pub enum JitOptimizationState {
|
|||
ElementWise(ElementWiseState),
|
||||
}
|
||||
|
||||
impl<R: Runtime> burn_fusion::Optimization<JitBackend<R>> for JitOptimization<R> {
|
||||
fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend<R>>) {
|
||||
impl<R, F, I> burn_fusion::Optimization<JitBackend<R, F, I>> for JitOptimization<R>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend<R, F, I>>) {
|
||||
match self {
|
||||
Self::ElementWise(op) => op.execute(context),
|
||||
}
|
||||
|
@ -53,7 +59,7 @@ impl<R: Runtime> burn_fusion::Optimization<JitBackend<R>> for JitOptimization<R>
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReprBackend for JitBackend<R> {
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
|
||||
type Handle = JitFusionHandle<R>;
|
||||
|
||||
fn float_tensor<const D: usize>(
|
||||
|
@ -96,7 +102,7 @@ impl<R: Runtime> ReprBackend for JitBackend<R> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> FusionBackend for JitBackend<R> {
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> FusionBackend for JitBackend<R, F, I> {
|
||||
type OptimizationState = JitOptimizationState;
|
||||
type Optimization = JitOptimization<R>;
|
||||
type FusionClient = MutexFusionClient<Self>;
|
||||
|
@ -104,7 +110,7 @@ impl<R: Runtime> FusionBackend for JitBackend<R> {
|
|||
fn optimizations(
|
||||
device: R::Device,
|
||||
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
|
||||
vec![Box::new(ElementWiseBuilder::new(device))]
|
||||
vec![Box::new(ElementWiseBuilder::<R, F, I>::new(device))]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use super::{optimization::ElementWise, CompilationPhase};
|
||||
use crate::{
|
||||
codegen::dialect::gpu::{
|
||||
|
@ -5,7 +7,7 @@ use crate::{
|
|||
},
|
||||
element::JitElement,
|
||||
fusion::{tracing::TraceBuilder, JitOptimization},
|
||||
JitBackend, Runtime,
|
||||
FloatElement, IntElement, JitBackend, Runtime,
|
||||
};
|
||||
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
|
||||
use burn_tensor::{
|
||||
|
@ -19,15 +21,22 @@ use burn_tensor::{
|
|||
};
|
||||
|
||||
/// Fused element wise operations that are normally memory bound.
|
||||
pub(crate) struct ElementWiseBuilder<R: Runtime> {
|
||||
pub(crate) struct ElementWiseBuilder<R: Runtime, F: FloatElement, I: IntElement> {
|
||||
builder: TraceBuilder,
|
||||
current_output_shape: Vec<usize>,
|
||||
status: OptimizationStatus,
|
||||
num_added: usize,
|
||||
device: R::Device,
|
||||
_float_elem: PhantomData<F>,
|
||||
_int_elem: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<R> {
|
||||
impl<R, F, I> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<R, F, I>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn register(&mut self, ops: &OperationDescription) {
|
||||
if let OptimizationStatus::Closed = self.status {
|
||||
return;
|
||||
|
@ -35,31 +44,31 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<
|
|||
|
||||
match ops {
|
||||
OperationDescription::BaseFloat(ops) => {
|
||||
if !self.register_base::<FloatElem<JitBackend<R>>>(ops) {
|
||||
if !self.register_base::<FloatElem<JitBackend<R, F, I>>>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::BaseInt(ops) => {
|
||||
if !self.register_base::<IntElem<JitBackend<R>>>(ops) {
|
||||
if !self.register_base::<IntElem<JitBackend<R, F, I>>>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::Float(ops) => {
|
||||
if !self.register_float::<FloatElem<JitBackend<R>>>(ops) {
|
||||
if !self.register_float::<FloatElem<JitBackend<R, F, I>>>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::NumericFloat(ops) => {
|
||||
if !self.register_numeric::<FloatElem<JitBackend<R>>, _>(ops) {
|
||||
if !self.register_numeric::<FloatElem<JitBackend<R, F, I>>, _>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::NumericInt(ops) => {
|
||||
if !self.register_numeric::<IntElem<JitBackend<R>>, _>(ops) {
|
||||
if !self.register_numeric::<IntElem<JitBackend<R, F, I>>, _>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
|
@ -110,14 +119,16 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> ElementWiseBuilder<R> {
|
||||
pub fn new(device: Device<JitBackend<R>>) -> Self {
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> ElementWiseBuilder<R, F, I> {
|
||||
pub fn new(device: Device<JitBackend<R, F, I>>) -> Self {
|
||||
Self {
|
||||
builder: TraceBuilder::new(),
|
||||
num_added: 0,
|
||||
current_output_shape: Vec::new(),
|
||||
status: OptimizationStatus::Open,
|
||||
device,
|
||||
_float_elem: PhantomData,
|
||||
_int_elem: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::{
|
|||
codegen::dialect::gpu::WorkgroupSize,
|
||||
compute::JitAutotuneKey,
|
||||
fusion::{kernel::FusionKernel, tracing::Trace},
|
||||
JitBackend, Runtime,
|
||||
FloatElement, IntElement, JitBackend, Runtime,
|
||||
};
|
||||
use burn_common::id::IdGenerator;
|
||||
use burn_compute::client::ComputeClient;
|
||||
|
@ -66,7 +66,10 @@ impl<R: Runtime> ElementWise<R, CompilationPhase> {
|
|||
}
|
||||
|
||||
impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
||||
pub(crate) fn execute(&mut self, context: &mut Context<'_, JitBackend<R>>) {
|
||||
pub(crate) fn execute<F: FloatElement, I: IntElement>(
|
||||
&mut self,
|
||||
context: &mut Context<'_, JitBackend<R, F, I>>,
|
||||
) {
|
||||
let client = R::client(&self.device);
|
||||
|
||||
let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new(
|
||||
|
@ -81,9 +84,9 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
}
|
||||
}
|
||||
|
||||
fn run_kernel(
|
||||
fn run_kernel<F: FloatElement, I: IntElement>(
|
||||
&mut self,
|
||||
context: &mut Context<'_, JitBackend<R>>,
|
||||
context: &mut Context<'_, JitBackend<R, F, I>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
fastest_set_index: usize,
|
||||
) {
|
||||
|
@ -106,9 +109,9 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
kernel.execute();
|
||||
}
|
||||
|
||||
fn run_autotune(
|
||||
fn run_autotune<F: FloatElement, I: IntElement>(
|
||||
&mut self,
|
||||
context: &mut Context<'_, JitBackend<R>>,
|
||||
context: &mut Context<'_, JitBackend<R, F, I>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
key: JitAutotuneKey,
|
||||
) {
|
||||
|
@ -152,9 +155,9 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
}
|
||||
|
||||
/// The first output is chosen when possible, otherwise the first input is chosen.
|
||||
pub(crate) fn autotune_shape<'a>(
|
||||
pub(crate) fn autotune_shape<'a, F: FloatElement, I: IntElement>(
|
||||
&self,
|
||||
context: &mut Context<'a, JitBackend<R>>,
|
||||
context: &mut Context<'a, JitBackend<R, F, I>>,
|
||||
) -> &'a [usize] {
|
||||
let info = self.trace.running();
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ use crate::fusion::strides_dyn_rank;
|
|||
use crate::fusion::JitFusionHandle;
|
||||
use crate::gpu::ComputeShader;
|
||||
use crate::kernel::GpuComputeShaderPhase;
|
||||
use crate::FloatElement;
|
||||
use crate::IntElement;
|
||||
use crate::JitBackend;
|
||||
use crate::Runtime;
|
||||
use burn_compute::client::ComputeClient;
|
||||
|
@ -106,14 +108,19 @@ impl<R: Runtime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
|
|||
}
|
||||
|
||||
impl<R: Runtime> FusionKernel<R> {
|
||||
pub fn create<K: FusionKernelFactory<R>>(
|
||||
pub fn create<K, F, I>(
|
||||
factory: &K,
|
||||
running_info: &ExecutionInfo<'_>,
|
||||
context: &mut Context<'_, JitBackend<R>>,
|
||||
device: Device<JitBackend<R>>,
|
||||
context: &mut Context<'_, JitBackend<R, F, I>>,
|
||||
device: Device<JitBackend<R, F, I>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
stateful: bool,
|
||||
) -> ExecutableKernel<R> {
|
||||
) -> ExecutableKernel<R>
|
||||
where
|
||||
K: FusionKernelFactory<R>,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
let (handles_input, inputs_description_updated, outputs_description_updated) =
|
||||
process_inputs_outputs(
|
||||
&running_info.inputs,
|
||||
|
@ -266,16 +273,21 @@ fn register_info_tensor<R: Runtime>(
|
|||
}
|
||||
}
|
||||
|
||||
fn process_inputs_outputs<'a, R: Runtime>(
|
||||
fn process_inputs_outputs<'a, R, F, I>(
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
context: &'a mut Context<'_, JitBackend<R>>,
|
||||
context: &'a mut Context<'_, JitBackend<R, F, I>>,
|
||||
stateful: bool,
|
||||
) -> (
|
||||
Vec<JitFusionHandle<R>>,
|
||||
Vec<&'a TensorDescription>,
|
||||
Vec<&'a TensorDescription>,
|
||||
) {
|
||||
)
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
let mut inputs_description_updated = Vec::with_capacity(inputs.len());
|
||||
let mut outputs_description_updated = Vec::with_capacity(outputs.len());
|
||||
let mut handles_input = Vec::new();
|
||||
|
|
|
@ -90,14 +90,14 @@ impl TraceBuilder {
|
|||
/// Create a variable from an input [scalar](Element).
|
||||
pub fn scalar<E: Element>(&mut self, _value: &E, elem_type: gpu::Elem) -> gpu::Variable {
|
||||
match elem_type {
|
||||
gpu::Elem::Float => {
|
||||
gpu::Elem::Float(_) => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_float as u16, elem_type);
|
||||
self.scalars.num_float += 1;
|
||||
var
|
||||
}
|
||||
gpu::Elem::Int => {
|
||||
gpu::Elem::Int(_) => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_int as u16, elem_type);
|
||||
|
|
|
@ -57,9 +57,10 @@ impl Trace {
|
|||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// NOTE: we might want to pass a struct including all inputs/outputs metadata instead of 3 arrays
|
||||
if self.scalars.num_float > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: gpu::Elem::Float,
|
||||
elem: gpu::Elem::Float(gpu::FloatKind::F32),
|
||||
size: self.scalars.num_float,
|
||||
})
|
||||
}
|
||||
|
@ -73,7 +74,7 @@ impl Trace {
|
|||
|
||||
if self.scalars.num_int > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: gpu::Elem::Int,
|
||||
elem: gpu::Elem::Int(gpu::IntKind::I32),
|
||||
size: self.scalars.num_int,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility},
|
||||
gpu::{gpu, ComputeShader, Elem, IntKind, Scope, Variable, Visibility},
|
||||
kernel::{self, GpuComputeShaderPhase},
|
||||
ops::{
|
||||
numeric::{empty_device, zeros_device},
|
||||
|
@ -99,8 +99,8 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let groups = Variable::GlobalScalar(6, Elem::UInt);
|
||||
|
||||
let stride_0_i = scope.create_local(Elem::Int);
|
||||
let stride_1_i = scope.create_local(Elem::Int);
|
||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
gpu!(scope, stride_0_i = cast(conv_stride_0));
|
||||
gpu!(scope, stride_1_i = cast(conv_stride_1));
|
||||
|
||||
|
@ -139,15 +139,15 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
gpu!(scope, ic_end = ic_start + ic_tmp);
|
||||
|
||||
let tmp_u = scope.create_local(Elem::UInt);
|
||||
let tmp_i = scope.create_local(Elem::Int);
|
||||
let zero_i = scope.zero(Elem::Int);
|
||||
let one_i = scope.create_with_value(1, Elem::Int);
|
||||
let tmp_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let zero_i = scope.zero(Elem::Int(IntKind::I32));
|
||||
let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32));
|
||||
|
||||
let kms_u = scope.create_local(Elem::UInt);
|
||||
let kms_0 = scope.create_local(Elem::Int);
|
||||
let kms_1 = scope.create_local(Elem::Int);
|
||||
let ih_start_tmp = scope.create_local(Elem::Int);
|
||||
let iw_start_tmp = scope.create_local(Elem::Int);
|
||||
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ih_start = scope.create_local(Elem::UInt);
|
||||
let iw_start = scope.create_local(Elem::UInt);
|
||||
let ih_end = scope.create_local(Elem::UInt);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::codegen::dialect::gpu::{gpu, Elem, Scope, Variable};
|
||||
use crate::codegen::Execution;
|
||||
use crate::gpu::ComputeShader;
|
||||
use crate::gpu::{ComputeShader, IntKind};
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
|
@ -80,7 +80,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for GatherEagerKernel<R, E
|
|||
fn compile(&self) -> ComputeShader {
|
||||
let mut scope = gpu::Scope::root();
|
||||
let item_tensor = E::gpu_elem().into();
|
||||
let item_indices: gpu::Item = gpu::Elem::Int.into();
|
||||
let item_indices: gpu::Item = gpu::Elem::Int(IntKind::I32).into();
|
||||
|
||||
let tensor = gpu::Variable::GlobalInputArray(0, item_tensor);
|
||||
let indices = scope.read_array(1, item_indices);
|
||||
|
@ -103,7 +103,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for GatherEagerKernel<R, E
|
|||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let indices = InputInfo::Array {
|
||||
item: gpu::Elem::Int.into(),
|
||||
item: gpu::Elem::Int(IntKind::I32).into(),
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let out = OutputInfo::Array { item: item_tensor };
|
||||
|
|
|
@ -133,10 +133,10 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for ScatterEagerKernel<R,
|
|||
fn compile(&self) -> ComputeShader {
|
||||
let mut scope = gpu::Scope::root();
|
||||
let item_value = E::gpu_elem().into();
|
||||
let item_indices: gpu::Item = gpu::Elem::Int.into();
|
||||
let item_indices: gpu::Item = gpu::Elem::Int(gpu::IntKind::I32).into();
|
||||
|
||||
let input_output = gpu::Variable::GlobalInputArray(0, item_value);
|
||||
let indices = gpu::Variable::GlobalInputArray(1, Elem::Int.into());
|
||||
let indices = gpu::Variable::GlobalInputArray(1, Elem::Int(gpu::IntKind::I32).into());
|
||||
let value = gpu::Variable::GlobalInputArray(2, item_value);
|
||||
|
||||
scope.write_global_custom(input_output);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility},
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
|
@ -74,7 +74,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for SelectEagerKernel<R, E
|
|||
fn compile(&self) -> ComputeShader {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
let item_indices: Item = Elem::Int.into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let indices = Variable::GlobalInputArray(1, item_indices);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Branch, Elem, Item, Scope, Variable, Visibility},
|
||||
dialect::gpu::{gpu, Branch, Elem, IntKind, Item, Scope, Variable, Visibility},
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
WorkgroupLaunch,
|
||||
},
|
||||
|
@ -132,7 +132,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for SelectAssignEagerKerne
|
|||
fn compile(&self) -> ComputeShader {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
let item_indices: Item = Elem::Int.into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item);
|
||||
let value = Variable::GlobalInputArray(1, item);
|
||||
|
|
|
@ -17,16 +17,18 @@ struct InterpolateBicubicEagerKernel<R, E> {
|
|||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct InterpolateBicubicShader {
|
||||
struct InterpolateBicubicShader<E> {
|
||||
input: Variable,
|
||||
output: Variable,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl InterpolateBicubicShader {
|
||||
impl<E: JitElement> InterpolateBicubicShader<E> {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let input = self.input;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
let elem = E::gpu_elem();
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
|
@ -83,20 +85,20 @@ impl InterpolateBicubicShader {
|
|||
|
||||
let input_height = scope.create_local(Elem::UInt);
|
||||
let output_height = scope.create_local(Elem::UInt);
|
||||
let output_height_float = scope.create_local(Elem::Float);
|
||||
let output_height_float = scope.create_local(elem);
|
||||
|
||||
let input_width = scope.create_local(Elem::UInt);
|
||||
let output_width = scope.create_local(Elem::UInt);
|
||||
let output_width_float = scope.create_local(Elem::Float);
|
||||
let output_width_float = scope.create_local(elem);
|
||||
|
||||
let frac = scope.create_local(Elem::Float);
|
||||
let frac = scope.create_local(elem);
|
||||
let numerator = scope.create_local(Elem::UInt);
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let not_zero = scope.create_local(Elem::Bool);
|
||||
|
||||
let y_in_float = scope.create_local(Elem::Float);
|
||||
let y_in_float = scope.create_local(elem);
|
||||
let y_in = scope.create_local(Elem::UInt);
|
||||
let yw = scope.create_local(Elem::Float);
|
||||
let yw = scope.create_local(elem);
|
||||
let y_tmp = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, input_height = input_shape_2 - 1u32);
|
||||
|
@ -123,9 +125,9 @@ impl InterpolateBicubicShader {
|
|||
gpu!(scope, y_tmp = y_in + 2u32);
|
||||
let y3 = Self::min(scope, y_tmp, input_height);
|
||||
|
||||
let x_in_float = scope.create_local(Elem::Float);
|
||||
let x_in_float = scope.create_local(elem);
|
||||
let x_in = scope.create_local(Elem::UInt);
|
||||
let xw = scope.create_local(Elem::Float);
|
||||
let xw = scope.create_local(elem);
|
||||
let x_tmp = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, input_width = input_shape_3 - 1u32);
|
||||
|
@ -374,7 +376,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for InterpolateBicubicEage
|
|||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
InterpolateBicubicShader { input, output }.expand(&mut scope);
|
||||
InterpolateBicubicShader {
|
||||
input,
|
||||
output,
|
||||
_elem: PhantomData::<E>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
|
|
|
@ -17,16 +17,18 @@ struct InterpolateNearestEagerKernel<R, E> {
|
|||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct InterpolateNearestShader {
|
||||
struct InterpolateNearestShader<E> {
|
||||
input: Variable,
|
||||
output: Variable,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl InterpolateNearestShader {
|
||||
impl<E: JitElement> InterpolateNearestShader<E> {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let input = self.input;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
let elem = E::gpu_elem();
|
||||
|
||||
let input_stride_0 = scope.create_local(Elem::UInt);
|
||||
let input_stride_1 = scope.create_local(Elem::UInt);
|
||||
|
@ -81,11 +83,11 @@ impl InterpolateNearestShader {
|
|||
gpu!(scope, w = id / output_stride_3);
|
||||
gpu!(scope, w = w % output_shape_3);
|
||||
|
||||
let factor_float = scope.create_local(Elem::Float);
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let denominator_float = scope.create_local(Elem::Float);
|
||||
let x = scope.create_local(Elem::Float);
|
||||
let y = scope.create_local(Elem::Float);
|
||||
let factor_float = scope.create_local(elem);
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let denominator_float = scope.create_local(elem);
|
||||
let x = scope.create_local(elem);
|
||||
let y = scope.create_local(elem);
|
||||
let xu = scope.create_local(Elem::UInt);
|
||||
let yu = scope.create_local(Elem::UInt);
|
||||
|
||||
|
@ -130,7 +132,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for InterpolateNearestEage
|
|||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
InterpolateNearestShader { input, output }.expand(&mut scope);
|
||||
InterpolateNearestShader {
|
||||
input,
|
||||
output,
|
||||
_elem: PhantomData::<E>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
|
|
|
@ -17,12 +17,13 @@ struct InterpolateNearestBackwardEagerKernel<R, E> {
|
|||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct InterpolateNearestBackwardShader {
|
||||
struct InterpolateNearestBackwardShader<E> {
|
||||
out_grad: Variable,
|
||||
output: Variable,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl InterpolateNearestBackwardShader {
|
||||
impl<E: JitElement> InterpolateNearestBackwardShader<E> {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let grad = self.out_grad;
|
||||
let output = self.output;
|
||||
|
@ -134,8 +135,9 @@ impl InterpolateNearestBackwardShader {
|
|||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let elem = E::gpu_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = input_index * output_size);
|
||||
|
@ -154,8 +156,9 @@ impl InterpolateNearestBackwardShader {
|
|||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let elem = E::gpu_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
|
@ -189,7 +192,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
|
|||
let out_grad = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
InterpolateNearestBackwardShader { out_grad, output }.expand(&mut scope);
|
||||
InterpolateNearestBackwardShader {
|
||||
out_grad,
|
||||
output,
|
||||
_elem: PhantomData::<E>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
|
|
|
@ -18,10 +18,11 @@ use std::marker::PhantomData;
|
|||
use super::simple_launch_options;
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct MatmulEagerKernel<R: Runtime> {
|
||||
struct MatmulEagerKernel<R: Runtime, E: JitElement> {
|
||||
workgroup_size_x: usize,
|
||||
workgroup_size_y: usize,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct MatmulComputeShader {
|
||||
|
@ -151,7 +152,7 @@ impl MatmulComputeShader {
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> GpuComputeShaderPhase for MatmulEagerKernel<R> {
|
||||
impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for MatmulEagerKernel<R, E> {
|
||||
fn compile(&self) -> ComputeShader {
|
||||
assert_eq!(
|
||||
self.workgroup_size_x, self.workgroup_size_y,
|
||||
|
@ -159,9 +160,17 @@ impl<R: Runtime> GpuComputeShaderPhase for MatmulEagerKernel<R> {
|
|||
);
|
||||
|
||||
let mut scope = gpu::Scope::root();
|
||||
let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
|
||||
let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
|
||||
let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
|
||||
let elem = E::gpu_elem();
|
||||
assert!(
|
||||
elem == gpu::Elem::Float(gpu::FloatKind::F32)
|
||||
|| elem == gpu::Elem::Float(gpu::FloatKind::F64),
|
||||
"Only float elements are supported."
|
||||
);
|
||||
let item = elem.into();
|
||||
|
||||
let lhs = gpu::Variable::GlobalInputArray(0, item);
|
||||
let rhs = gpu::Variable::GlobalInputArray(1, item);
|
||||
let out = gpu::Variable::GlobalOutputArray(0, item);
|
||||
|
||||
scope.write_global_custom(out);
|
||||
|
||||
|
@ -172,16 +181,14 @@ impl<R: Runtime> GpuComputeShaderPhase for MatmulEagerKernel<R> {
|
|||
.expand(&mut scope);
|
||||
|
||||
let lhs = InputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
item,
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let rhs = InputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
item,
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let out = OutputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
};
|
||||
let out = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![lhs, rhs],
|
||||
|
@ -236,7 +243,7 @@ pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
|
|||
workgroup_size_y,
|
||||
);
|
||||
|
||||
let kernel = MatmulEagerKernel::<R>::new(workgroup_size_x, workgroup_size_y);
|
||||
let kernel = MatmulEagerKernel::<R, E>::new(workgroup_size_x, workgroup_size_y);
|
||||
|
||||
Execution::start(kernel, rhs.client)
|
||||
.inputs(&[
|
||||
|
|
|
@ -26,18 +26,27 @@ struct MatmulTiling2d<E: JitElement> {
|
|||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct MatmulTiling2dEagerKernel<R: Runtime> {
|
||||
struct MatmulTiling2dEagerKernel<R: Runtime, E: JitElement> {
|
||||
config: Tiling2dConfig,
|
||||
bounds_check_required: bool,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> GpuComputeShaderPhase for MatmulTiling2dEagerKernel<R> {
|
||||
impl<R: Runtime, E: JitElement> GpuComputeShaderPhase for MatmulTiling2dEagerKernel<R, E> {
|
||||
fn compile(&self) -> ComputeShader {
|
||||
let mut scope = gpu::Scope::root();
|
||||
let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
|
||||
let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
|
||||
let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
|
||||
let elem = E::gpu_elem();
|
||||
assert!(
|
||||
elem == gpu::Elem::Float(gpu::FloatKind::F32)
|
||||
|| elem == gpu::Elem::Float(gpu::FloatKind::F64),
|
||||
"Only float elements are supported."
|
||||
);
|
||||
let item = elem.into();
|
||||
|
||||
let lhs = gpu::Variable::GlobalInputArray(0, item);
|
||||
let rhs = gpu::Variable::GlobalInputArray(1, item);
|
||||
let out = gpu::Variable::GlobalOutputArray(0, item);
|
||||
|
||||
scope.write_global_custom(out);
|
||||
|
||||
|
@ -49,16 +58,14 @@ impl<R: Runtime> GpuComputeShaderPhase for MatmulTiling2dEagerKernel<R> {
|
|||
.expand(&mut scope);
|
||||
|
||||
let lhs = InputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
item,
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let rhs = InputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
item,
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let out = OutputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
};
|
||||
let out = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![lhs, rhs],
|
||||
|
@ -94,7 +101,7 @@ pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(
|
|||
) -> JitTensor<R, E, D> {
|
||||
let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config);
|
||||
|
||||
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), bounds_check_required);
|
||||
let kernel = MatmulTiling2dEagerKernel::<R, E>::new(config.clone(), bounds_check_required);
|
||||
let client = lhs.client.clone();
|
||||
|
||||
let lhs = match lhs.batch_swapped_with_row_col() {
|
||||
|
@ -126,7 +133,7 @@ pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement + Element, const D: usi
|
|||
out: JitTensor<R, E, D>,
|
||||
config: Tiling2dConfig,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), false);
|
||||
let kernel = MatmulTiling2dEagerKernel::<R, E>::new(config.clone(), false);
|
||||
let client = lhs.client.clone();
|
||||
|
||||
// A tensor may need to be padded, in which case it will implicitly become contiguous
|
||||
|
|
|
@ -18,12 +18,13 @@ struct AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
|
|||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dBackwardComputeShader {
|
||||
struct AdaptiveAvgPool2dBackwardComputeShader<E> {
|
||||
grad: Variable,
|
||||
output: Variable,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dBackwardComputeShader {
|
||||
impl<E: JitElement> AdaptiveAvgPool2dBackwardComputeShader<E> {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let grad = self.grad;
|
||||
let output = self.output;
|
||||
|
@ -158,8 +159,9 @@ impl AdaptiveAvgPool2dBackwardComputeShader {
|
|||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let elem = E::gpu_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index * input_size);
|
||||
|
@ -177,8 +179,9 @@ impl AdaptiveAvgPool2dBackwardComputeShader {
|
|||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let elem = E::gpu_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
|
@ -213,7 +216,12 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
|
|||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
AdaptiveAvgPool2dBackwardComputeShader { grad, output }.expand(&mut scope);
|
||||
AdaptiveAvgPool2dBackwardComputeShader {
|
||||
grad,
|
||||
output,
|
||||
_elem: PhantomData::<E>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let grad = InputInfo::Array {
|
||||
item,
|
||||
|
|
|
@ -137,8 +137,9 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
|
|||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let elem = E::gpu_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index * input_size);
|
||||
|
@ -156,8 +157,9 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
|
|||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let elem = E::gpu_elem();
|
||||
let numerator_float = scope.create_local(elem);
|
||||
let div = scope.create_local(elem);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
dialect::gpu::{gpu, Elem, IntKind, Scope, Variable, Visibility},
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
|
@ -223,17 +223,17 @@ impl AvgPool2dBackwardComputeShader {
|
|||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let signed_ih = scope.create_local(Elem::Int);
|
||||
let signed_iw = scope.create_local(Elem::Int);
|
||||
let signed_ih = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_iw = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
let signed_pool_stride_0 = scope.create_local(Elem::Int);
|
||||
let signed_pool_stride_1 = scope.create_local(Elem::Int);
|
||||
let signed_dilation_0 = scope.create_local(Elem::Int);
|
||||
let signed_dilation_1 = scope.create_local(Elem::Int);
|
||||
let signed_padding_0 = scope.create_local(Elem::Int);
|
||||
let signed_padding_1 = scope.create_local(Elem::Int);
|
||||
let signed_kernel_size_0 = scope.create_local(Elem::Int);
|
||||
let signed_kernel_size_1 = scope.create_local(Elem::Int);
|
||||
let signed_pool_stride_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_pool_stride_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_dilation_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_dilation_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_padding_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_padding_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
|
||||
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
|
||||
|
@ -248,8 +248,8 @@ impl AvgPool2dBackwardComputeShader {
|
|||
gpu!(scope, signed_ih = cast(ih));
|
||||
gpu!(scope, signed_iw = cast(iw));
|
||||
|
||||
let kms_0 = scope.create_local(Elem::Int);
|
||||
let kms_1 = scope.create_local(Elem::Int);
|
||||
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
|
||||
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
|
||||
|
@ -257,8 +257,8 @@ impl AvgPool2dBackwardComputeShader {
|
|||
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
|
||||
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
|
||||
|
||||
let oh_start_tmp = scope.create_local(Elem::Int);
|
||||
let ow_start_tmp = scope.create_local(Elem::Int);
|
||||
let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
|
||||
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
|
||||
|
@ -277,8 +277,8 @@ impl AvgPool2dBackwardComputeShader {
|
|||
gpu!(scope, oh_start = cast(oh_start_tmp));
|
||||
gpu!(scope, ow_start = cast(ow_start_tmp));
|
||||
|
||||
let oh_end_tmp = scope.create_local(Elem::Int);
|
||||
let ow_end_tmp = scope.create_local(Elem::Int);
|
||||
let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
|
||||
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility},
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
|
@ -94,7 +94,7 @@ impl MaxPool2dBackwardComputeShader {
|
|||
gpu!(scope, index_current_tmp = iw * output_stride_3);
|
||||
gpu!(scope, index_current += index_current_tmp);
|
||||
|
||||
let index_select = scope.create_local(Elem::Int);
|
||||
let index_select = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
let index_max = scope.create_local(Elem::UInt);
|
||||
let is_max = scope.create_local(Elem::Bool);
|
||||
|
@ -169,17 +169,17 @@ impl MaxPool2dBackwardComputeShader {
|
|||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let signed_ih = scope.create_local(Elem::Int);
|
||||
let signed_iw = scope.create_local(Elem::Int);
|
||||
let signed_ih = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_iw = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
let signed_pool_stride_0 = scope.create_local(Elem::Int);
|
||||
let signed_pool_stride_1 = scope.create_local(Elem::Int);
|
||||
let signed_dilation_0 = scope.create_local(Elem::Int);
|
||||
let signed_dilation_1 = scope.create_local(Elem::Int);
|
||||
let signed_padding_0 = scope.create_local(Elem::Int);
|
||||
let signed_padding_1 = scope.create_local(Elem::Int);
|
||||
let signed_kernel_size_0 = scope.create_local(Elem::Int);
|
||||
let signed_kernel_size_1 = scope.create_local(Elem::Int);
|
||||
let signed_pool_stride_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_pool_stride_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_dilation_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_dilation_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_padding_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_padding_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0));
|
||||
gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1));
|
||||
|
@ -194,8 +194,8 @@ impl MaxPool2dBackwardComputeShader {
|
|||
gpu!(scope, signed_ih = cast(ih));
|
||||
gpu!(scope, signed_iw = cast(iw));
|
||||
|
||||
let kms_0 = scope.create_local(Elem::Int);
|
||||
let kms_1 = scope.create_local(Elem::Int);
|
||||
let kms_0 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let kms_1 = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0);
|
||||
gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0);
|
||||
|
@ -203,8 +203,8 @@ impl MaxPool2dBackwardComputeShader {
|
|||
gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1);
|
||||
gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1);
|
||||
|
||||
let oh_start_tmp = scope.create_local(Elem::Int);
|
||||
let ow_start_tmp = scope.create_local(Elem::Int);
|
||||
let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0);
|
||||
gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0);
|
||||
|
@ -223,8 +223,8 @@ impl MaxPool2dBackwardComputeShader {
|
|||
gpu!(scope, oh_start = cast(oh_start_tmp));
|
||||
gpu!(scope, ow_start = cast(ow_start_tmp));
|
||||
|
||||
let oh_end_tmp = scope.create_local(Elem::Int);
|
||||
let ow_end_tmp = scope.create_local(Elem::Int);
|
||||
let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
||||
gpu!(scope, oh_end_tmp = max(kms_0, 0i32));
|
||||
gpu!(scope, ow_end_tmp = max(kms_1, 0i32));
|
||||
|
@ -268,7 +268,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
|
|||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int));
|
||||
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int(IntKind::I32)));
|
||||
let grad = Variable::GlobalInputArray(1, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
|
@ -283,7 +283,7 @@ impl<R: Runtime, E: JitElement> GpuComputeShaderPhase
|
|||
.expand(&mut scope);
|
||||
|
||||
let indices = InputInfo::Array {
|
||||
item: Item::Scalar(Elem::Int),
|
||||
item: Item::Scalar(Elem::Int(IntKind::I32)),
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo},
|
||||
gpu::{gpu, ComputeShader, Elem, Item, Scope, Variable, Visibility},
|
||||
gpu::{gpu, ComputeShader, Elem, IntKind, Item, Scope, Variable, Visibility},
|
||||
kernel::GpuComputeShaderPhase,
|
||||
JitElement, Runtime,
|
||||
};
|
||||
|
@ -182,7 +182,10 @@ impl<P: PoolStrategy, R: Runtime, E: JitElement> GpuComputeShaderPhase
|
|||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let indices = if P::with_indices() {
|
||||
Some(Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int)))
|
||||
Some(Variable::GlobalOutputArray(
|
||||
1,
|
||||
Item::Scalar(Elem::Int(IntKind::I32)),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
@ -213,7 +216,7 @@ impl<P: PoolStrategy, R: Runtime, E: JitElement> GpuComputeShaderPhase
|
|||
vec![
|
||||
output,
|
||||
OutputInfo::Array {
|
||||
item: Item::Scalar(Elem::Int),
|
||||
item: Item::Scalar(Elem::Int(IntKind::I32)),
|
||||
},
|
||||
]
|
||||
} else {
|
||||
|
|
|
@ -44,7 +44,7 @@ impl<E: JitElement> Prng<E> for Bernoulli<E> {
|
|||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
let float_random = scope.create_local(Elem::Float);
|
||||
let float_random = scope.create_local(E::gpu_elem());
|
||||
cast_uint_to_float(scope, int_random, float_random);
|
||||
|
||||
let bernoulli = scope.create_local(Elem::Bool);
|
||||
|
|
|
@ -33,10 +33,11 @@ impl<E: JitElement> Prng<E> for Normal<E> {
|
|||
state_3: Variable,
|
||||
output: Variable,
|
||||
) {
|
||||
let elem = E::gpu_elem();
|
||||
let item = output.item();
|
||||
let mean = args[0];
|
||||
let std = args[1];
|
||||
let two_pi = scope.create_with_value(2. * PI, Elem::Float);
|
||||
let two_pi = scope.create_with_value(2. * PI, elem);
|
||||
let t_neg = scope.create_with_value(-2.0, item);
|
||||
let two: Variable = 2u32.into();
|
||||
|
||||
|
@ -55,7 +56,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
|
|||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
let unit_0 = scope.create_local(Elem::Float);
|
||||
let unit_0 = scope.create_local(elem);
|
||||
cast_uint_to_float(scope, int_random, unit_0);
|
||||
|
||||
// Second random uniform integer
|
||||
|
@ -68,7 +69,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
|
|||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
let unit_1 = scope.create_local(Elem::Float);
|
||||
let unit_1 = scope.create_local(elem);
|
||||
cast_uint_to_float(scope, int_random, unit_1);
|
||||
|
||||
// Box-Muller transform
|
||||
|
|
|
@ -31,6 +31,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
|
|||
state_3: Variable,
|
||||
output: Variable,
|
||||
) {
|
||||
let elem = E::gpu_elem();
|
||||
let item = output.item();
|
||||
let lower_bound = args[0];
|
||||
let upper_bound = args[1];
|
||||
|
@ -50,12 +51,12 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
|
|||
gpu!(scope, int_random = int_random ^ state_2);
|
||||
gpu!(scope, int_random = int_random ^ state_3);
|
||||
|
||||
let float_random = scope.create_local(Elem::Float);
|
||||
let float_scale = scope.create_local(Elem::Float);
|
||||
let float_random = scope.create_local(elem);
|
||||
let float_scale = scope.create_local(elem);
|
||||
cast_uint_to_float(scope, int_random, float_random);
|
||||
gpu!(scope, float_scale = cast(scale));
|
||||
|
||||
let uniform_float = scope.create_local(Elem::Float);
|
||||
let uniform_float = scope.create_local(elem);
|
||||
let uniform = scope.create_local(item);
|
||||
gpu!(scope, uniform_float = float_random * float_scale);
|
||||
gpu!(scope, uniform = cast(uniform_float));
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{JitBackend, Runtime};
|
||||
use crate::{FloatElement, IntElement, JitBackend, Runtime};
|
||||
use burn_tensor::ops::ActivationOps;
|
||||
|
||||
impl<R: Runtime> ActivationOps<Self> for JitBackend<R> {}
|
||||
impl<R: Runtime, F: FloatElement, I: IntElement> ActivationOps<Self> for JitBackend<R, F, I> {}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{kernel, JitBackend, Runtime};
|
||||
use crate::{kernel, FloatElement, IntElement, JitBackend, Runtime};
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
|
||||
use burn_tensor::Reader;
|
||||
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
|
||||
|
@ -6,7 +6,12 @@ use std::ops::Range;
|
|||
|
||||
use super::{expand, permute};
|
||||
|
||||
impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
|
||||
impl<R, F, I> BoolTensorOps<Self> for JitBackend<R, F, I>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
|
||||
super::empty(shape, device)
|
||||
}
|
||||
|
|
|
@ -3,14 +3,19 @@ use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryO
|
|||
use crate::kernel::matmul::{matmul, MatmulStrategy};
|
||||
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
||||
use crate::kernel::{self, reduce};
|
||||
use crate::Runtime;
|
||||
use crate::{unary, JitBackend};
|
||||
use crate::{FloatElement, IntElement, Runtime};
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
|
||||
use burn_tensor::{ops::FloatTensorOps, Data, Distribution, Shape};
|
||||
use burn_tensor::{ElementConversion, Reader};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
||||
impl<R, F, I> FloatTensorOps<Self> for JitBackend<R, F, I>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn float_from_data<const D: usize>(
|
||||
data: Data<FloatElem<Self>, D>,
|
||||
device: &Device<Self>,
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
use super::{expand, numeric, permute};
|
||||
use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator};
|
||||
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
||||
use crate::{kernel, unary, JitBackend, Runtime};
|
||||
use crate::{kernel, unary, FloatElement, IntElement, JitBackend, Runtime};
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use burn_tensor::{ops::IntTensorOps, Data, Distribution, ElementConversion, Reader, Shape};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
|
||||
impl<R, F, I> IntTensorOps<Self> for JitBackend<R, F, I>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
|
||||
super::empty(shape, device)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
use crate::{kernel, JitBackend, Runtime};
|
||||
use crate::{kernel, FloatElement, IntElement, JitBackend, Runtime};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
|
||||
ModuleOps,
|
||||
};
|
||||
use burn_tensor::ops::{FloatTensor, IntTensor};
|
||||
|
||||
impl<R: Runtime> ModuleOps<Self> for JitBackend<R> {
|
||||
impl<R, F, I> ModuleOps<Self> for JitBackend<R, F, I>
|
||||
where
|
||||
R: Runtime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn conv2d(
|
||||
x: FloatTensor<Self, 4>,
|
||||
weight: FloatTensor<Self, 4>,
|
||||
|
|
|
@ -4,9 +4,6 @@ use crate::{
|
|||
};
|
||||
use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
|
||||
|
||||
/// Type alias to the runtime signed int element type.
|
||||
pub type RuntimeInt<R> = <<R as Runtime>::Compiler as Compiler>::Int;
|
||||
|
||||
/// Runtime for the [just-in-time backend](crate::JitBackend).
|
||||
pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
|
||||
/// The compiler used to compile the inner representation into tokens.
|
||||
|
@ -26,17 +23,6 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
|
|||
+ Sync
|
||||
+ Send;
|
||||
|
||||
/// A version of the runtime that supports full precision.
|
||||
///
|
||||
/// Note that the runtime should share all other runtime components.
|
||||
/// This way, it's possible to share the same handles for both runtimes and reduce data copies to a minimum.
|
||||
type FullPrecisionRuntime: Runtime<
|
||||
Compiler = <Self::Compiler as Compiler>::FullPrecisionCompiler,
|
||||
Device = Self::Device,
|
||||
Server = Self::Server,
|
||||
Channel = Self::Channel,
|
||||
>;
|
||||
|
||||
/// Retrieve the compute client from the runtime device.
|
||||
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ macro_rules! testgen_jit {
|
|||
use super::*;
|
||||
use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test};
|
||||
|
||||
pub type TestBackend = JitBackend<TestRuntime>;
|
||||
pub type TestBackend = JitBackend<TestRuntime, f32, i32>;
|
||||
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
|
@ -106,7 +106,7 @@ macro_rules! testgen_jit_fusion {
|
|||
use super::*;
|
||||
use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor};
|
||||
|
||||
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime>>;
|
||||
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime, f32, i32>>;
|
||||
pub type ReferenceBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
use super::LocalArray;
|
||||
use super::{shader::ComputeShader, Item, SharedMemory};
|
||||
use crate::compiler::wgsl;
|
||||
use crate::{FloatElement, IntElement};
|
||||
use burn_jit::gpu;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Wgsl Compiler.
|
||||
#[derive(Clone)]
|
||||
pub struct WgslCompiler<F: FloatElement, I: IntElement> {
|
||||
#[derive(Clone, Default)]
|
||||
pub struct WgslCompiler {
|
||||
num_inputs: usize,
|
||||
num_outputs: usize,
|
||||
local_invocation_index: bool,
|
||||
|
@ -21,43 +19,16 @@ pub struct WgslCompiler<F: FloatElement, I: IntElement> {
|
|||
num_workgroups: bool,
|
||||
shared_memories: Vec<SharedMemory>,
|
||||
local_arrays: Vec<LocalArray>,
|
||||
_float: PhantomData<F>,
|
||||
_int: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<F: FloatElement, I: IntElement> core::fmt::Debug for WgslCompiler<F, I> {
|
||||
impl core::fmt::Debug for WgslCompiler {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("WgslCompiler")
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: FloatElement, I: IntElement> Default for WgslCompiler<F, I> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_inputs: 0,
|
||||
num_outputs: 0,
|
||||
local_invocation_index: false,
|
||||
local_invocation_id: false,
|
||||
global_invocation_id: false,
|
||||
workgroup_id: false,
|
||||
rank: false,
|
||||
id: false,
|
||||
stride: false,
|
||||
shape: false,
|
||||
num_workgroups: false,
|
||||
shared_memories: Vec::default(),
|
||||
local_arrays: Vec::default(),
|
||||
_float: PhantomData,
|
||||
_int: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
|
||||
impl burn_jit::Compiler for WgslCompiler {
|
||||
type Representation = ComputeShader;
|
||||
type Float = F;
|
||||
type Int = I;
|
||||
type FullPrecisionCompiler = WgslCompiler<f32, i32>;
|
||||
|
||||
fn compile(shader: gpu::ComputeShader) -> Self::Representation {
|
||||
let mut compiler = Self::default();
|
||||
|
@ -73,7 +44,7 @@ impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
||||
impl WgslCompiler {
|
||||
fn compile_shader(&mut self, mut value: gpu::ComputeShader) -> wgsl::ComputeShader {
|
||||
self.num_inputs = value.inputs.len();
|
||||
self.num_outputs = value.outputs.len();
|
||||
|
@ -128,8 +99,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
|
||||
fn compile_elem(value: gpu::Elem) -> wgsl::Elem {
|
||||
match value {
|
||||
gpu::Elem::Float => F::wgpu_elem(),
|
||||
gpu::Elem::Int => I::wgpu_elem(),
|
||||
gpu::Elem::Float(f) => match f {
|
||||
gpu::FloatKind::F32 => wgsl::Elem::F32,
|
||||
gpu::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
|
||||
},
|
||||
gpu::Elem::Int(i) => match i {
|
||||
gpu::IntKind::I32 => wgsl::Elem::I32,
|
||||
gpu::IntKind::I64 => panic!("i64 is not a valid WgpuElement"),
|
||||
},
|
||||
gpu::Elem::UInt => wgsl::Elem::U32,
|
||||
gpu::Elem::Bool => wgsl::Elem::Bool,
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
|
|||
/// You can disable the `fusion` feature flag to remove that functionality, which might be
|
||||
/// necessary on `wasm` for now.
|
||||
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
|
||||
burn_fusion::Fusion<JitBackend<WgpuRuntime<G, F, I>>>;
|
||||
burn_fusion::Fusion<JitBackend<WgpuRuntime<G>, F, I>>;
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
|
||||
|
@ -64,13 +64,13 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
|
|||
///
|
||||
/// You can enable the `fusion` feature flag to add that functionality, which might improve
|
||||
/// performance.
|
||||
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G, F, I>>;
|
||||
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G>, F, I>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi, f32, i32>;
|
||||
pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi>;
|
||||
|
||||
burn_jit::testgen_all!();
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compiler::wgsl,
|
||||
compute::{WgpuServer, WgpuStorage},
|
||||
FloatElement, GraphicsApi, IntElement, WgpuDevice,
|
||||
GraphicsApi, WgpuDevice,
|
||||
};
|
||||
use alloc::sync::Arc;
|
||||
use burn_common::stub::RwLock;
|
||||
|
@ -19,13 +19,10 @@ use wgpu::{AdapterInfo, DeviceDescriptor};
|
|||
|
||||
/// Runtime that uses the [wgpu] crate with the wgsl compiler.
|
||||
///
|
||||
/// The [graphics api](GraphicsApi), the [float element](FloatElement) and the
|
||||
/// [int element](IntElement) types are passed as generic.
|
||||
/// The [graphics api](GraphicsApi) type is passed as generic.
|
||||
#[derive(Debug)]
|
||||
pub struct WgpuRuntime<G: GraphicsApi, F: FloatElement, I: IntElement> {
|
||||
pub struct WgpuRuntime<G: GraphicsApi> {
|
||||
_g: PhantomData<G>,
|
||||
_f: PhantomData<F>,
|
||||
_i: PhantomData<I>,
|
||||
}
|
||||
|
||||
/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime).
|
||||
|
@ -34,9 +31,8 @@ static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>>
|
|||
|
||||
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Runtime for WgpuRuntime<G, F, I> {
|
||||
type FullPrecisionRuntime = WgpuRuntime<G, f32, i32>;
|
||||
type Compiler = wgsl::WgslCompiler<F, I>;
|
||||
impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
|
||||
type Compiler = wgsl::WgslCompiler;
|
||||
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
|
||||
type Channel = MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>;
|
||||
|
|
|
@ -71,7 +71,7 @@ fn autodiff<B: AutodiffBackend>(device: &B::Device) {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;
|
||||
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi>, f32, i32>;
|
||||
type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
|
||||
let device = Default::default();
|
||||
inference::<MyBackend>(&device);
|
||||
|
|
|
@ -15,7 +15,7 @@ use burn::{
|
|||
};
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
|
||||
for Autodiff<JitBackend<WgpuRuntime<G, F, I>>>
|
||||
for Autodiff<JitBackend<WgpuRuntime<G>, F, I>>
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
|
|||
}
|
||||
|
||||
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G, F, I>> {
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
|
||||
fn fused_matmul_add_relu<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
|
|
Loading…
Reference in New Issue