mirror of https://github.com/tracel-ai/burn.git
adaptive
This commit is contained in:
parent
aa9de1a88a
commit
71b04ca0a1
|
@ -275,6 +275,12 @@ macro_rules! gpu {
|
|||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = floor(input)
|
||||
($scope:expr, $out:ident = floor($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Floor(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = ceil(input)
|
||||
($scope:expr, $out:ident = ceil($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Ceil(
|
||||
|
|
|
@ -36,6 +36,8 @@ pub enum Operator {
|
|||
Tanh(UnaryOperator),
|
||||
Powf(BinaryOperator),
|
||||
Sqrt(UnaryOperator),
|
||||
Floor(UnaryOperator),
|
||||
Ceil(UnaryOperator),
|
||||
Erf(UnaryOperator),
|
||||
Recip(UnaryOperator),
|
||||
Equal(BinaryOperator),
|
||||
|
|
|
@ -43,6 +43,8 @@ impl Operator {
|
|||
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
|
||||
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
|
||||
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
|
||||
Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)),
|
||||
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
|
||||
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
|
||||
Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)),
|
||||
Operator::Log(op) => Operator::Log(op.vectorize(vectorization)),
|
||||
|
|
|
@ -321,6 +321,16 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Floor(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Ceil(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Modulo(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
|
|
|
@ -77,11 +77,10 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
|
|||
gpu!(scope, ow = id / output_stride_3);
|
||||
gpu!(scope, ow = ow % output_shape_3);
|
||||
|
||||
let ih_start = scope.create_local(Elem::UInt);
|
||||
let ih_end = scope.create_local(Elem::UInt);
|
||||
let iw_start = scope.create_local(Elem::UInt);
|
||||
let iw_end = scope.create_local(Elem::UInt);
|
||||
// TODO COMPUTE THEM ^
|
||||
let ih_start = Self::start_index(scope, oh, output_shape_2, input_shape_2);
|
||||
let ih_end = Self::end_index(scope, oh, output_shape_2, input_shape_2);
|
||||
let iw_start = Self::start_index(scope, ow, output_shape_3, input_shape_3);
|
||||
let iw_end = Self::end_index(scope, ow, output_shape_3, input_shape_3);
|
||||
|
||||
let result = scope.create_local(input.item());
|
||||
|
||||
|
@ -129,7 +128,55 @@ impl<R: Runtime, E: JitElement> AdaptivePool2dComputeShader<R, E> {
|
|||
|
||||
gpu!(scope, count_float = cast(count));
|
||||
gpu!(scope, avg = sum / count_float);
|
||||
gpu!(scope, output[id] = sum);
|
||||
gpu!(scope, output[id] = avg);
|
||||
}
|
||||
|
||||
fn start_index(
|
||||
scope: &mut Scope,
|
||||
output_size_index: Variable,
|
||||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index * input_size);
|
||||
gpu!(scope, numerator_float = cast(index));
|
||||
gpu!(scope, div = cast(output_size));
|
||||
gpu!(scope, div = numerator_float / div);
|
||||
gpu!(scope, div = floor(div));
|
||||
gpu!(scope, index = cast(div));
|
||||
index
|
||||
}
|
||||
|
||||
fn end_index(
|
||||
scope: &mut Scope,
|
||||
output_size_index: Variable,
|
||||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index + 1u32);
|
||||
gpu!(scope, index *= input_size);
|
||||
gpu!(scope, numerator_float = cast(index));
|
||||
gpu!(scope, div = cast(output_size));
|
||||
gpu!(scope, div = numerator_float / div);
|
||||
gpu!(scope, div = ceil(div));
|
||||
gpu!(scope, index = cast(div));
|
||||
|
||||
gpu!(scope, min = input_size < index);
|
||||
gpu!(scope, if(min).then(|scope|{
|
||||
gpu!(scope, end_index = input_size);
|
||||
}).else(|scope|{
|
||||
gpu!(scope, end_index = index);
|
||||
}));
|
||||
end_index
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -445,6 +445,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Floor(op) => wgsl::Instruction::Floor {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Erf(op) => wgsl::Instruction::Erf {
|
||||
input: self.compile_variable(op.input),
|
||||
out: self.compile_variable(op.out),
|
||||
|
|
|
@ -210,6 +210,14 @@ pub enum Instruction {
|
|||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Floor {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Ceil {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Instruction {
|
||||
|
@ -461,6 +469,12 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
|||
Instruction::ShiftRight { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
|
||||
}
|
||||
Instruction::Floor { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = floor({input});\n"))
|
||||
}
|
||||
Instruction::Ceil { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = ceil({input});\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue