This commit is contained in:
louisfd 2024-03-22 08:25:49 -04:00
parent aa9de1a88a
commit 71b04ca0a1
7 changed files with 95 additions and 6 deletions

View File

@ -275,6 +275,12 @@ macro_rules! gpu {
gpu!(unary $input, $out)
));
};
// out = floor(input)
//
// Matches `gpu!(scope, out = floor(input))` and registers an `Operator::Floor`
// on `$scope`. The operator payload is produced by the `unary` arm of this
// same macro from `$input` and `$out` (that arm is defined elsewhere in the
// macro; presumably it builds a `UnaryOperator` — confirm there).
($scope:expr, $out:ident = floor($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Floor(
gpu!(unary $input, $out)
));
};
// out = ceil(input)
($scope:expr, $out:ident = ceil($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Ceil(

View File

@ -36,6 +36,8 @@ pub enum Operator {
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
Floor(UnaryOperator),
Ceil(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),
Equal(BinaryOperator),

View File

@ -43,6 +43,8 @@ impl Operator {
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)),
Operator::Log(op) => Operator::Log(op.vectorize(vectorization)),

View File

@ -321,6 +321,16 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Floor(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Ceil(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Modulo(op) => mark_binary(
op,
&mut local_tensor_ids_input,

View File

@ -77,11 +77,10 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
gpu!(scope, ow = id / output_stride_3);
gpu!(scope, ow = ow % output_shape_3);
let ih_start = scope.create_local(Elem::UInt);
let ih_end = scope.create_local(Elem::UInt);
let iw_start = scope.create_local(Elem::UInt);
let iw_end = scope.create_local(Elem::UInt);
// TODO COMPUTE THEM ^
let ih_start = Self::start_index(scope, oh, output_shape_2, input_shape_2);
let ih_end = Self::end_index(scope, oh, output_shape_2, input_shape_2);
let iw_start = Self::start_index(scope, ow, output_shape_3, input_shape_3);
let iw_end = Self::end_index(scope, ow, output_shape_3, input_shape_3);
let result = scope.create_local(input.item());
@ -129,7 +128,55 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
gpu!(scope, count_float = cast(count));
gpu!(scope, avg = sum / count_float);
gpu!(scope, output[id] = sum);
gpu!(scope, output[id] = avg);
}
/// Emits the instructions that compute the first input index covered by the
/// pooling window of one output position along a single axis:
///
/// `floor(output_size_index * input_size / output_size)`
///
/// # Parameters
/// - `scope`: scope the generated instructions are registered on.
/// - `output_size_index`: position along the output axis.
/// - `output_size`: length of the output axis.
/// - `input_size`: length of the input axis.
///
/// # Returns
/// A fresh `Elem::UInt` local holding the start index.
///
/// The quotient is taken with unsigned integer division instead of a
/// cast-to-float / divide / `floor` / cast-back sequence. For non-negative
/// operands both forms describe the same value, but the float form loses
/// exactness once the numerator exceeds the f32 integer range, and a GPU float
/// division that is not correctly rounded can land just below an exact integer
/// quotient, making `floor` return one less than intended.
///
/// NOTE(review): assumes all three operands are unsigned-integer variables, as
/// index/shape values are elsewhere in this shader — confirm against callers.
fn start_index(
    scope: &mut Scope,
    output_size_index: Variable,
    output_size: Variable,
    input_size: Variable,
) -> Variable {
    let index = scope.create_local(Elem::UInt);

    gpu!(scope, index = output_size_index * input_size);
    // Unsigned integer division truncates, which is `floor` for these operands.
    gpu!(scope, index = index / output_size);

    index
}
/// Emits the instructions that compute the exclusive end input index of the
/// pooling window of one output position along a single axis:
///
/// `min(ceil((output_size_index + 1) * input_size / output_size), input_size)`
///
/// # Parameters
/// - `scope`: scope the generated instructions are registered on.
/// - `output_size_index`: position along the output axis.
/// - `output_size`: length of the output axis.
/// - `input_size`: length of the input axis; also the upper clamp of the result.
///
/// # Returns
/// A fresh `Elem::UInt` local holding the clamped end index.
///
/// The ceiling division is done in unsigned integer arithmetic using
/// `ceil(a / b) == (a + b - 1) / b` rather than via a float round-trip. The
/// float form loses exactness for large numerators, and a GPU float division
/// that is not correctly rounded can land just above an exact integer
/// quotient, making `ceil` return one more than intended.
///
/// NOTE(review): assumes all three operands are unsigned-integer variables and
/// that `output_size` is non-zero — confirm against callers.
fn end_index(
    scope: &mut Scope,
    output_size_index: Variable,
    output_size: Variable,
    input_size: Variable,
) -> Variable {
    let index = scope.create_local(Elem::UInt);
    let exceeds_input = scope.create_local(Elem::Bool);
    let end_index = scope.create_local(Elem::UInt);

    // Numerator: (output_size_index + 1) * input_size.
    gpu!(scope, index = output_size_index + 1u32);
    gpu!(scope, index = index * input_size);

    // Integer ceiling division: (numerator + output_size - 1) / output_size.
    gpu!(scope, index = index + output_size);
    gpu!(scope, index = index - 1u32);
    gpu!(scope, index = index / output_size);

    // Clamp so the window never runs past the end of the input axis.
    gpu!(scope, exceeds_input = input_size < index);
    gpu!(scope, if(exceeds_input).then(|scope|{
        gpu!(scope, end_index = input_size);
    }).else(|scope|{
        gpu!(scope, end_index = index);
    }));

    end_index
}
}

View File

@ -445,6 +445,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Floor(op) => wgsl::Instruction::Floor {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Erf(op) => wgsl::Instruction::Erf {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),

View File

@ -210,6 +210,14 @@ pub enum Instruction {
rhs: Variable,
out: Variable,
},
Floor {
input: Variable,
out: Variable,
},
Ceil {
input: Variable,
out: Variable,
},
}
impl Display for Instruction {
@ -461,6 +469,12 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
Instruction::ShiftRight { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
}
Instruction::Floor { input, out } => {
f.write_fmt(format_args!("{out} = floor({input});\n"))
}
Instruction::Ceil { input, out } => {
f.write_fmt(format_args!("{out} = ceil({input});\n"))
}
}
}
}