Mirror of https://github.com/tracel-ai/burn.git

Commit 9eecc713a4 (parent c7d4c23f97)
JIT: Fix min & max values (#1429)

* real min and max values
* fix
* fmt
@@ -15,6 +15,10 @@ where
     fn from_bytes(bytes: &[u8]) -> &[Self];
     /// Element representation for `gpu`.
     fn gpu_elem() -> gpu::Elem;
+    /// Highest possible value
+    fn maximum_value() -> Self;
+    /// Lowest possible value
+    fn minimum_value() -> Self;
 }

 /// The float element type for the jit backend.

@@ -36,6 +40,12 @@ impl JitElement for u32 {
     fn gpu_elem() -> gpu::Elem {
         gpu::Elem::UInt
     }
+    fn maximum_value() -> Self {
+        u32::MAX
+    }
+    fn minimum_value() -> Self {
+        u32::MIN
+    }
 }

 impl JitElement for i32 {

@@ -51,6 +61,12 @@ impl JitElement for i32 {
     fn gpu_elem() -> gpu::Elem {
         gpu::Elem::Int
     }
+    fn maximum_value() -> Self {
+        i32::MAX
+    }
+    fn minimum_value() -> Self {
+        i32::MIN
+    }
 }

 impl JitElement for f32 {

@@ -63,10 +79,15 @@ impl JitElement for f32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }

     fn gpu_elem() -> gpu::Elem {
         gpu::Elem::Float
     }
+    fn maximum_value() -> Self {
+        f32::MAX
+    }
+    fn minimum_value() -> Self {
+        f32::MIN
+    }
 }

 impl FloatElement for f32 {}
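The new trait methods give each element type its true numeric bounds, replacing the hard-coded sentinel seeds (±32767.0) used by the reduction kernels below. A minimal, self-contained sketch of the pattern, using a hypothetical `Bounded` trait in place of burn's `JitElement`:

// Hypothetical stand-in for the extended JitElement trait; not burn's API.
trait Bounded: Copy {
    fn maximum_value() -> Self;
    fn minimum_value() -> Self;
}

impl Bounded for f32 {
    fn maximum_value() -> Self { f32::MAX }
    fn minimum_value() -> Self { f32::MIN }
}

impl Bounded for i32 {
    fn maximum_value() -> Self { i32::MAX }
    fn minimum_value() -> Self { i32::MIN }
}

fn main() {
    // f32::MIN (about -3.4e38) is far below the old -32767.0 sentinel.
    assert!(<f32 as Bounded>::minimum_value() < -32767.0);
    assert_eq!(<i32 as Bounded>::maximum_value(), i32::MAX);
}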
@@ -28,14 +28,15 @@ struct MaxPool2dWithIndicesEagerKernel<R: Runtime, E: JitElement> {
     _elem: PhantomData<E>,
 }

-struct MaxPool2dComputeShader {
+struct MaxPool2dComputeShader<E: JitElement> {
     x: Variable,
     output: Variable,
     kernel_size: [usize; 2],
     indices: Option<Variable>,
+    _elem: PhantomData<E>,
 }

-impl MaxPool2dComputeShader {
+impl<E: JitElement> MaxPool2dComputeShader<E> {
     fn expand(self, scope: &mut Scope) {
         let x = self.x;
         let output = self.output;

@@ -121,9 +122,12 @@ impl MaxPool2dComputeShader {
         let index_input_4 = scope.create_local(Elem::UInt);

         let is_max = scope.create_local(Elem::Bool);
-        let max_val = scope.create_local(x.item());
         let max_index = self.indices.map(|_| scope.create_local(Elem::UInt));
-        gpu!(scope, max_val = cast(-32767.0));
+
+        let max_val = scope.create_local(x.item());
+        let max_initial =
+            Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), x.item().elem());
+        gpu!(scope, max_val = max_initial);

         (0..kernel_size_0).for_each(|kh| {
             gpu!(scope, ih = oh * pool_stride_0);

@@ -206,6 +210,7 @@ impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dEagerKernel<R,
             output,
             kernel_size: self.kernel_size,
             indices: None,
+            _elem: PhantomData::<E>,
         }
         .expand(&mut scope);

@@ -256,6 +261,7 @@ impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dWithIndicesEage
             output,
             kernel_size: self.kernel_size,
             indices: Some(indices),
+            _elem: PhantomData::<E>,
         }
         .expand(&mut scope);
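Why the seeding change matters for pooling: with the accumulator seeded at -32767.0, any window whose values are all below that sentinel would "max" to the sentinel itself. A minimal CPU sketch of the fixed behavior (assumed semantics, not the GPU kernel; `max_pool1d` is a hypothetical helper):

// 1-D max pool whose accumulator starts at f32::MIN, mirroring
// the kernel's new E::minimum_value() seed.
fn max_pool1d(x: &[f32], kernel: usize, stride: usize) -> Vec<f32> {
    let mut out = Vec::new();
    let mut start = 0;
    while start + kernel <= x.len() {
        // Seed with the type minimum, not a sentinel like -32767.0.
        let mut max_val = f32::MIN;
        for &v in &x[start..start + kernel] {
            if v > max_val {
                max_val = v;
            }
        }
        out.push(max_val);
        start += stride;
    }
    out
}

fn main() {
    // Every input is below the old sentinel; the fixed seeding still
    // recovers the true window maxima.
    let x = [-1.0e6_f32, -2.0e6, -0.5e6, -3.0e6];
    assert_eq!(max_pool1d(&x, 2, 2), vec![-1.0e6, -0.5e6]);
}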
@@ -1,10 +1,13 @@
-use crate::codegen::dialect::gpu::{gpu, Elem, Item, Scope, Variable};
+use crate::{
+    codegen::dialect::gpu::{gpu, Elem, Item, Scope, Variable},
+    JitElement,
+};

 use super::ReduceDimAlgorithm;

 pub(crate) struct ArgMax;

-impl ReduceDimAlgorithm for ArgMax {
+impl<E: JitElement> ReduceDimAlgorithm<E> for ArgMax {
     type Accumulator = (Variable, Variable);

     fn initialize_naive(

@@ -12,9 +15,12 @@ impl ReduceDimAlgorithm for ArgMax {
         input_item: Item,
         _output_item: Item,
     ) -> Self::Accumulator {
-        let max = scope.create_local(input_item);
         let index = scope.create_local(Elem::UInt);
-        gpu!(scope, max = cast(-32767.0));
+
+        let max = scope.create_local(input_item);
+        let max_initial =
+            Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
+        gpu!(scope, max = max_initial);

         (max, index)
     }

@@ -50,8 +56,8 @@ impl ReduceDimAlgorithm for ArgMax {
     ) -> Self::Accumulator {
         let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
         let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
-        let max = scope.create_local(input_item);
-        gpu!(scope, max = cast(-32767.0));
+
+        let max = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
         gpu!(scope, value_shared_memory[write_position] = max);
         (value_shared_memory, index_shared_memory)
     }
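A hedged CPU analogue of the fixed argmax initialization (not the GPU codegen): the (max, index) accumulator starts at the element type's minimum, so inputs entirely below the old -32767.0 sentinel still yield a valid index. The ArgMin change in the next file is symmetric, seeding with `maximum_value()`.

fn argmax(values: &[f32]) -> (f32, usize) {
    let mut max = f32::MIN; // was effectively cast(-32767.0) before the fix
    let mut index = 0;
    for (i, &v) in values.iter().enumerate() {
        if v > max {
            max = v;
            index = i;
        }
    }
    (max, index)
}

fn main() {
    // With the sentinel seed, no element ever exceeded -32767.0 here and the
    // returned index stayed at its default; the true answer is index 1.
    assert_eq!(argmax(&[-999_999.0, -999_997.0, -999_998.0]).1, 1);
}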
@@ -1,10 +1,13 @@
-use crate::codegen::dialect::gpu::{gpu, Elem, Item, Scope, Variable};
+use crate::{
+    codegen::dialect::gpu::{gpu, Elem, Item, Scope, Variable},
+    JitElement,
+};

 use super::ReduceDimAlgorithm;

 pub(crate) struct ArgMin;

-impl ReduceDimAlgorithm for ArgMin {
+impl<E: JitElement> ReduceDimAlgorithm<E> for ArgMin {
     type Accumulator = (Variable, Variable);

     fn initialize_naive(

@@ -12,9 +15,12 @@ impl ReduceDimAlgorithm for ArgMin {
         input_item: Item,
         _output_item: Item,
     ) -> Self::Accumulator {
-        let min = scope.create_local(input_item);
         let index = scope.create_local(Elem::UInt);
-        gpu!(scope, min = cast(32767.0));
+
+        let min = scope.create_local(input_item);
+        let min_initial =
+            Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
+        gpu!(scope, min = min_initial);

         (min, index)
     }

@@ -50,8 +56,8 @@ impl ReduceDimAlgorithm for ArgMin {
     ) -> Self::Accumulator {
         let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
         let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
-        let min = scope.create_local(input_item);
-        gpu!(scope, min = cast(32767.0));
+
+        let min = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
        gpu!(scope, value_shared_memory[write_position] = min);
         (value_shared_memory, index_shared_memory)
     }
@@ -10,7 +10,7 @@ use crate::{
 use super::{reduce_dim_naive, reduce_dim_shared, ArgMax, ArgMin, MeanDim, SumDim};

 /// Specifies the reduce dim algorithm in use
-pub trait ReduceDimAlgorithm: Send + Sync + 'static {
+pub trait ReduceDimAlgorithm<E: JitElement>: Send + Sync + 'static {
     /// The reduction accumulator
     type Accumulator: Copy;
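The structural change behind most of this diff: `ReduceDimAlgorithm` becomes generic over the element type, so a single impl can seed its accumulator from that type's bounds. A hedged sketch of the shape of the design, with simplified names that are not burn's API:

// Hypothetical Bounded trait standing in for JitElement.
trait Bounded: Copy + PartialOrd {
    fn minimum_value() -> Self;
}
impl Bounded for f32 { fn minimum_value() -> Self { f32::MIN } }
impl Bounded for i32 { fn minimum_value() -> Self { i32::MIN } }

// Simplified analogue of ReduceDimAlgorithm<E>.
trait ReduceAlg<E: Bounded>: Send + Sync + 'static {
    type Accumulator: Copy;
    fn initialize() -> Self::Accumulator;
}

struct ArgMax;

// One blanket impl covers every element type; the seed adapts automatically.
impl<E: Bounded> ReduceAlg<E> for ArgMax {
    type Accumulator = (E, u32);
    fn initialize() -> Self::Accumulator {
        (E::minimum_value(), 0)
    }
}

fn main() {
    let (seed, index) = <ArgMax as ReduceAlg<f32>>::initialize();
    assert_eq!(seed, f32::MIN);
    assert_eq!(index, 0);
}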
@@ -1,10 +1,13 @@
-use crate::codegen::dialect::gpu::{gpu, Item, Scope, Variable};
+use crate::{
+    codegen::dialect::gpu::{gpu, Item, Scope, Variable},
+    JitElement,
+};

 use super::ReduceDimAlgorithm;

 pub(crate) struct MeanDim;

-impl ReduceDimAlgorithm for MeanDim {
+impl<E: JitElement> ReduceDimAlgorithm<E> for MeanDim {
     type Accumulator = Variable;

     fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
@@ -14,16 +14,17 @@ use crate::{

 use super::ReduceDimAlgorithm;

-pub(crate) struct NaiveReduceDimComputeShader<RD: ReduceDimAlgorithm> {
+pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimAlgorithm<E>> {
     tensor: Variable,
     dim: usize,
     output: Variable,
     _reduce_dim: PhantomData<RD>,
+    _elem: PhantomData<E>,
 }

 #[derive(new)]
 pub(crate) struct NaiveReduceDimEagerKernel<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,

@@ -35,7 +36,7 @@ pub(crate) struct NaiveReduceDimEagerKernel<
     _elem_out: PhantomData<EO>,
 }

-impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
+impl<RD: ReduceDimAlgorithm<EI>, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
     for NaiveReduceDimEagerKernel<RD, R, EI, EO>
 {
     fn source(&self) -> crate::kernel::SourceTemplate {

@@ -51,6 +52,7 @@ impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement> Dynamic
             dim: self.dim,
             output,
             _reduce_dim: PhantomData::<RD>,
+            _elem: PhantomData::<EI>,
         }
         .expand(&mut scope);

@@ -80,7 +82,7 @@ impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement> Dynamic
     }
 }

-impl<RD: ReduceDimAlgorithm> NaiveReduceDimComputeShader<RD> {
+impl<E: JitElement, RD: ReduceDimAlgorithm<E>> NaiveReduceDimComputeShader<E, RD> {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let tensor = self.tensor;
         let dim: Variable = self.dim.into();

@@ -140,7 +142,7 @@ impl<RD: ReduceDimAlgorithm> NaiveReduceDimComputeShader<RD> {

 /// Executes the naive kernel for reduce dim
 pub fn reduce_dim_naive<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,
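A note on the `_elem: PhantomData<E>` fields threaded through these structs: Rust rejects a struct type parameter that no field mentions (error E0392), so the shader structs carry a zero-sized marker to tie `E` to the type. A minimal sketch with hypothetical names:

use std::marker::PhantomData;

// Without `_elem`, `struct Shader<E> { dim: usize }` would not compile.
struct Shader<E> {
    dim: usize,
    _elem: PhantomData<E>, // zero-sized; only records E in the type
}

impl<E> Shader<E> {
    fn new(dim: usize) -> Self {
        Self { dim, _elem: PhantomData }
    }
}

fn main() {
    let s = Shader::<f32>::new(1);
    assert_eq!(std::mem::size_of::<PhantomData<f32>>(), 0);
    assert_eq!(s.dim, 1);
}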
@@ -17,7 +17,7 @@ use crate::{

 use super::ReduceDimAlgorithm;

-pub(crate) struct SharedReduceDimComputeShader<RD: ReduceDimAlgorithm> {
+pub(crate) struct SharedReduceDimComputeShader<E: JitElement, RD: ReduceDimAlgorithm<E>> {
     tensor: Variable,
     dim: usize,
     shared_memory_size: usize,

@@ -25,11 +25,12 @@ pub(crate) struct SharedReduceDimComputeShader<RD: ReduceDimAlgorithm> {
     output: Variable,
     divisible_shape: bool,
     _reduce_dim: PhantomData<RD>,
+    _elem: PhantomData<E>,
 }

 #[derive(new)]
 pub(crate) struct SharedReduceDimEagerKernel<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,

@@ -45,7 +46,7 @@ pub(crate) struct SharedReduceDimEagerKernel<
     _elem_out: PhantomData<EO>,
 }

-impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
+impl<RD: ReduceDimAlgorithm<EI>, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
     for SharedReduceDimEagerKernel<RD, R, EI, EO>
 {
     fn source(&self) -> crate::kernel::SourceTemplate {

@@ -65,6 +66,7 @@ impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement> Dynamic
             output,
             divisible_shape: self.divisible_shape,
             _reduce_dim: PhantomData::<RD>,
+            _elem: PhantomData::<EI>,
         }
         .expand(&mut scope);

@@ -106,7 +108,7 @@ impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement> Dynamic
     }
 }

-impl<RD: ReduceDimAlgorithm> SharedReduceDimComputeShader<RD> {
+impl<E: JitElement, RD: ReduceDimAlgorithm<E>> SharedReduceDimComputeShader<E, RD> {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let tensor = self.tensor;
         let output = self.output;

@@ -231,7 +233,7 @@ impl<RD: ReduceDimAlgorithm> SharedReduceDimComputeShader<RD> {

 /// Executes the shared memory kernel for reduce dim
 pub fn reduce_dim_shared<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,
@@ -1,10 +1,13 @@
-use crate::codegen::dialect::gpu::{gpu, Item, Scope, Variable};
+use crate::{
+    codegen::dialect::gpu::{gpu, Item, Scope, Variable},
+    JitElement,
+};

 use super::ReduceDimAlgorithm;

 pub(crate) struct SumDim;

-impl ReduceDimAlgorithm for SumDim {
+impl<E: JitElement> ReduceDimAlgorithm<E> for SumDim {
     type Accumulator = Variable;

     fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
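Worth noting the asymmetry the diff shows: SumDim and MeanDim only gain the generic parameter, while ArgMax and ArgMin also change their initialization. That is because a reduction's seed must be the identity of its combining operator: zero for sums, but the type's extreme value for max/min. A hedged generic sketch (hypothetical helper, not burn code):

fn reduce<T: Copy>(values: &[T], identity: T, combine: impl Fn(T, T) -> T) -> T {
    values.iter().copied().fold(identity, combine)
}

fn main() {
    let xs = [-999_999.0_f32, -999_997.0, -999_998.0];
    let sum = reduce(&xs, 0.0, |a, b| a + b); // identity of +: 0
    let max = reduce(&xs, f32::MIN, f32::max); // identity of max: type minimum
    assert_eq!(max, -999_997.0);
    assert_eq!(sum, -2_999_994.0);
}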
@@ -21,7 +21,7 @@ use super::ReduceAutotuneKey;
 /// Autotune key is given by concatenating the closest upper power of 2 of
 /// dim to reduce, and product of others
 pub(crate) struct ReduceDimAutotuneOperationSet<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,

@@ -33,7 +33,7 @@ pub(crate) struct ReduceDimAutotuneOperationSet<
     reduce_dim: usize,
     _algorithm: PhantomData<RD>,
 }
-impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement, const D: usize>
+impl<RD: ReduceDimAlgorithm<EI>, R: Runtime, EI: JitElement, EO: JitElement, const D: usize>
     ReduceDimAutotuneOperationSet<RD, R, EI, EO, D>
 {
     fn new(input: JitTensor<R, EI, D>, output: JitTensor<R, EO, D>, reduce_dim: usize) -> Self {

@@ -51,7 +51,7 @@ impl<RD: ReduceDimAlgorithm, R: Runtime, EI: JitElement, EO: JitElement, const D
     }
 }

-impl<RD: ReduceDimAlgorithm, R, EI, EO, const D: usize> AutotuneOperationSet<JitAutotuneKey>
+impl<RD: ReduceDimAlgorithm<EI>, R, EI, EO, const D: usize> AutotuneOperationSet<JitAutotuneKey>
     for ReduceDimAutotuneOperationSet<RD, R, EI, EO, D>
 where
     R: Runtime,

@@ -105,7 +105,7 @@ where

 /// Executes autotune on reduce_dim operation
 pub(crate) fn reduce_dim_autotune<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement + Element,
     EO: JitElement + Element,

@@ -132,7 +132,7 @@ pub(crate) fn reduce_dim_autotune<
 #[derive(new)]
 // Probably better on balanced tensor shapes
 pub(crate) struct ReduceDimNaiveAutotune<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,

@@ -146,7 +146,7 @@ pub(crate) struct ReduceDimNaiveAutotune<

 impl<RD, R, EI, EO, const D: usize> AutotuneOperation for ReduceDimNaiveAutotune<RD, R, EI, EO, D>
 where
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,

@@ -169,7 +169,7 @@ where
 #[derive(new)]
 // Probably better on tensors large along reduce dim
 pub(crate) struct ReduceDimSharedAutotune<
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,

@@ -183,7 +183,7 @@ pub(crate) struct ReduceDimSharedAutotune<

 impl<RD, R, EI, EO, const D: usize> AutotuneOperation for ReduceDimSharedAutotune<RD, R, EI, EO, D>
 where
-    RD: ReduceDimAlgorithm,
+    RD: ReduceDimAlgorithm<EI>,
     R: Runtime,
     EI: JitElement,
     EO: JitElement,
@@ -236,4 +236,64 @@ mod reduction {

         val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
     }
+
+    #[test]
+    fn reduction_argmax_shared_memory_extreme_values_float() {
+        let data: Data<f32, 1> = Data::from([-999999., -999997., -999998.]);
+        let tensor = Tensor::<TestBackend, 1>::from_data(data, &Default::default());
+
+        let val_shared =
+            Tensor::<TestBackend, 1, Int>::from_primitive(argmax::<TestRuntime, f32, i32, 1>(
+                tensor.into_primitive(),
+                0,
+                ReduceStrategy::SharedMemory,
+            ));
+
+        assert_eq!(1, val_shared.into_data().value[0]);
+    }
+
+    #[test]
+    fn reduction_argmin_shared_memory_extreme_values_float() {
+        let data: Data<f32, 1> = Data::from([999999., 999998., 999997.]);
+        let tensor = Tensor::<TestBackend, 1>::from_data(data, &Default::default());
+
+        let val_shared =
+            Tensor::<TestBackend, 1, Int>::from_primitive(argmin::<TestRuntime, f32, i32, 1>(
+                tensor.into_primitive(),
+                0,
+                ReduceStrategy::SharedMemory,
+            ));
+
+        assert_eq!(2, val_shared.into_data().value[0]);
+    }
+
+    #[test]
+    fn reduction_argmin_shared_memory_extreme_values_i32() {
+        let data: Data<i32, 1> = Data::from([999999, 999998, 999997]);
+        let tensor = Tensor::<TestBackend, 1, Int>::from_data(data, &Default::default());
+
+        let val_shared =
+            Tensor::<TestBackend, 1, Int>::from_primitive(argmin::<TestRuntime, i32, i32, 1>(
+                tensor.into_primitive(),
+                0,
+                ReduceStrategy::SharedMemory,
+            ));
+
+        assert_eq!(2, val_shared.into_data().value[0]);
+    }
+
+    #[test]
+    fn reduction_argmax_shared_memory_extreme_values_i32() {
+        let data: Data<i32, 1> = Data::from([-999999, -999997, -999998]);
+        let tensor = Tensor::<TestBackend, 1, Int>::from_data(data, &Default::default());
+
+        let val_shared =
+            Tensor::<TestBackend, 1, Int>::from_primitive(argmax::<TestRuntime, i32, i32, 1>(
+                tensor.into_primitive(),
+                0,
+                ReduceStrategy::SharedMemory,
+            ));
+
+        assert_eq!(1, val_shared.into_data().value[0]);
+    }
 }
@@ -213,7 +213,12 @@ impl Display for Variable {
             Variable::GlobalScalar(number, _, elem) => {
                 f.write_fmt(format_args!("scalars_{elem}[{number}]"))
             }
-            Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")),
+            Variable::ConstantScalar(number, elem) => match elem {
+                Elem::F32 => f.write_fmt(format_args!("{number}f")),
+                Elem::I32 => f.write_fmt(format_args!("{number}i")),
+                Elem::U32 => f.write_fmt(format_args!("{number}u")),
+                Elem::Bool => f.write_fmt(format_args!("bool({number})")),
+            },
             Variable::SharedMemory(number, _, _) => {
                 f.write_fmt(format_args!("shared_memory_{number}"))
            }
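The old arm printed every constant as a cast, e.g. `f32(-32767)`; the new arm emits WGSL-style suffixed literals (`1.5f`, `3i`, `7u`), presumably so extreme seeds like `f32::MIN` render as plain literals in the generated shader source. A standalone sketch of the same match, with simplified types rather than burn's actual `Variable` enum:

use std::fmt::{self, Display};

#[allow(dead_code)]
enum Elem { F32, I32, U32, Bool }

// Hypothetical simplified constant; burn stores the value as f64 and the
// element kind alongside it.
struct ConstantScalar(f64, Elem);

impl Display for ConstantScalar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.1 {
            Elem::F32 => write!(f, "{}f", self.0),
            Elem::I32 => write!(f, "{}i", self.0),
            Elem::U32 => write!(f, "{}u", self.0),
            Elem::Bool => write!(f, "bool({})", self.0),
        }
    }
}

fn main() {
    assert_eq!(ConstantScalar(3.0, Elem::U32).to_string(), "3u");
}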