forked from OSchip/llvm-project
[mlir][linalg][python] Add min operation in OpDSL.
Add the min operation to OpDSL and introduce a min pooling operation to test the implementation. The patch is a sibling of the max operation patch https://reviews.llvm.org/D105203 and the min operation is again lowered to a compare and select pair. Differential Revision: https://reviews.llvm.org/D105345
This commit is contained in:
parent
7c5d654f64
commit
f239026f89
|
@ -664,6 +664,77 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_arg: I
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: pooling_nhwc_min_poly
|
||||
cpp_class_name: PoolingNhwcMinPolyOp
|
||||
doc: |-
|
||||
Performs min pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
data type as the accumulator/output.
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: I
|
||||
usage: InputOperand
|
||||
type_var: T1
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||
(s0, s1, s2, s3)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: K
|
||||
usage: InputOperand
|
||||
type_var: T2
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||
(s4, s5)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: OutputOperand
|
||||
type_var: U
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||
(s0, s6, s7, s3)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: strides
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
|
||||
-> (s8, s9)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: dilations
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
|
||||
-> (s10, s11)>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||
s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||
s10, s11] -> (d3, d4)>
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||
s10, s11] -> (d0, d1, d2, d5)>
|
||||
iterator_types:
|
||||
- parallel
|
||||
- parallel
|
||||
- parallel
|
||||
- reduction
|
||||
- reduction
|
||||
- parallel
|
||||
assignments:
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
fn_name: min
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: I
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: fill_rng_2d
|
||||
cpp_class_name: FillRng2DOp
|
||||
|
|
|
@ -275,17 +275,18 @@ public:
|
|||
}
|
||||
|
||||
Value applyfn__max(Value lhs, Value rhs) {
|
||||
OpBuilder builder = getBuilder();
|
||||
if (isFloatingPoint(lhs)) {
|
||||
Value condition =
|
||||
builder.create<CmpFOp>(lhs.getLoc(), CmpFPredicate::OGT, lhs, rhs);
|
||||
return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
|
||||
}
|
||||
if (isInteger(lhs)) {
|
||||
Value condition =
|
||||
builder.create<CmpIOp>(lhs.getLoc(), CmpIPredicate::sgt, lhs, rhs);
|
||||
return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
|
||||
}
|
||||
if (isFloatingPoint(lhs))
|
||||
return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
|
||||
if (isInteger(lhs))
|
||||
return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
|
||||
llvm_unreachable("unsupported non numeric type");
|
||||
}
|
||||
|
||||
Value applyfn__min(Value lhs, Value rhs) {
|
||||
if (isFloatingPoint(lhs))
|
||||
return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
|
||||
if (isInteger(lhs))
|
||||
return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
|
||||
llvm_unreachable("unsupported non numeric type");
|
||||
}
|
||||
|
||||
|
@ -322,6 +323,17 @@ private:
|
|||
MLIRContext *context;
|
||||
Block █
|
||||
|
||||
Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
|
||||
OpBuilder builder = getBuilder();
|
||||
Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
|
||||
return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
|
||||
}
|
||||
Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
|
||||
OpBuilder builder = getBuilder();
|
||||
Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
|
||||
return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
|
||||
}
|
||||
|
||||
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
|
||||
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
|
||||
|
||||
|
|
|
@ -339,6 +339,7 @@ class PrimFn:
|
|||
log = PrimFnType("log")
|
||||
mul = PrimFnType("mul")
|
||||
max = PrimFnType("max")
|
||||
min = PrimFnType("min")
|
||||
sub = PrimFnType("sub")
|
||||
|
||||
|
||||
|
@ -364,6 +365,7 @@ class ReduceFn:
|
|||
add = PrimFn.add.reduce
|
||||
mul = PrimFn.mul.reduce
|
||||
max = PrimFn.max.reduce
|
||||
min = PrimFn.min.reduce
|
||||
|
||||
|
||||
class PrimApply(TensorExpression):
|
||||
|
|
|
@ -308,17 +308,23 @@ class _BodyBuilder:
|
|||
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
|
||||
|
||||
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
|
||||
i1 = IntegerType.get_signless(1)
|
||||
if _is_floating_point_type(lhs.type):
|
||||
ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
|
||||
cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result
|
||||
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||
return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
|
||||
cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result
|
||||
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||
return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
|
||||
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
|
||||
|
||||
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
|
||||
return _emit_cmpf_and_select(lhs, rhs, olt_attr)
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
|
||||
return _emit_cmpi_and_select(lhs, rhs, slt_attr)
|
||||
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
|
||||
|
||||
|
||||
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
|
||||
in_arg_defs: Sequence[OperandDefConfig],
|
||||
|
@ -397,3 +403,13 @@ def _get_floating_point_width(t: Type) -> int:
|
|||
if BF16Type.isinstance(t):
|
||||
return 16
|
||||
raise NotImplementedError(f"Unhandled floating point type switch {t}")
|
||||
|
||||
|
||||
def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
|
||||
cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
|
||||
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||
|
||||
|
||||
def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
|
||||
cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
|
||||
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||
|
|
|
@ -166,6 +166,24 @@ def pooling_nhwc_max_poly(
|
|||
D.c]))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def pooling_nhwc_min_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=AttributeDef(S.SH, S.SW),
|
||||
dilations=AttributeDef(S.DH, S.DW)):
|
||||
"""Performs min pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
|
||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def fill_rng_2d(
|
||||
min=ScalarDef(F64),
|
||||
|
|
|
@ -90,6 +90,36 @@ func @generalize_pooling_nhwc_max_poly_i32(%input : tensor<1x4x16x1xi32>, %shape
|
|||
|
||||
// -----
|
||||
|
||||
func @generalize_pooling_nhwc_min_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
|
||||
%0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
||||
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
|
||||
return %0: tensor<1x2x4x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_f32
|
||||
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
|
||||
// CHECK-NEXT: %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
|
||||
// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
|
||||
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @generalize_pooling_nhwc_min_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
|
||||
%0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
||||
ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
|
||||
return %0: tensor<1x2x4x1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_i32
|
||||
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
|
||||
// CHECK-NEXT: %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
|
||||
// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
|
||||
// CHECK-NEXT: linalg.yield %[[MAX]] : i32
|
||||
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
|
||||
|
||||
// -----
|
||||
|
||||
func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
|
||||
%0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
||||
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
|
||||
|
|
|
@ -43,7 +43,7 @@ def conv_poly(
|
|||
|
||||
|
||||
@linalg_structured_op
|
||||
def pooling_poly(
|
||||
def pooling_max_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
|
@ -55,6 +55,19 @@ def pooling_poly(
|
|||
D.c]))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def pooling_min_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=AttributeDef(S.SH, S.SW),
|
||||
dilations=AttributeDef(S.DH, S.DW)):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
|
||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def fill_rng_poly(
|
||||
min=ScalarDef(F64),
|
||||
|
@ -216,7 +229,7 @@ with Context() as ctx, Location.unknown():
|
|||
return conv_poly(
|
||||
input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||
|
||||
# CHECK-LABEL: @test_f32i32_pooling
|
||||
# CHECK-LABEL: @test_f32i32_max_pooling
|
||||
# CHECK: linalg.generic
|
||||
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
|
||||
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
|
||||
|
@ -229,11 +242,11 @@ with Context() as ctx, Location.unknown():
|
|||
@builtin.FuncOp.from_py_func(
|
||||
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
||||
RankedTensorType.get((2, 4), i32))
|
||||
def test_f32i32_pooling(input, shape, init_result):
|
||||
return pooling_poly(
|
||||
def test_f32i32_max_pooling(input, shape, init_result):
|
||||
return pooling_max_poly(
|
||||
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||
|
||||
# CHECK-LABEL: @test_f32f32_pooling
|
||||
# CHECK-LABEL: @test_f32f32_max_pooling
|
||||
# CHECK: linalg.generic
|
||||
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
|
||||
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
|
||||
|
@ -245,8 +258,26 @@ with Context() as ctx, Location.unknown():
|
|||
@builtin.FuncOp.from_py_func(
|
||||
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
||||
RankedTensorType.get((2, 4), f32))
|
||||
def test_f32f32_pooling(input, shape, init_result):
|
||||
return pooling_poly(
|
||||
def test_f32f32_max_pooling(input, shape, init_result):
|
||||
return pooling_max_poly(
|
||||
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||
|
||||
# CHECK-LABEL: @test_f32i32_min_pooling
|
||||
# CHECK: = cmpi slt,
|
||||
@builtin.FuncOp.from_py_func(
|
||||
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
||||
RankedTensorType.get((2, 4), i32))
|
||||
def test_f32i32_min_pooling(input, shape, init_result):
|
||||
return pooling_min_poly(
|
||||
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||
|
||||
# CHECK-LABEL: @test_f32f32_min_pooling
|
||||
# CHECK: = cmpf olt,
|
||||
@builtin.FuncOp.from_py_func(
|
||||
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
||||
RankedTensorType.get((2, 4), f32))
|
||||
def test_f32f32_min_pooling(input, shape, init_result):
|
||||
return pooling_min_poly(
|
||||
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||
|
||||
# CHECK-LABEL: @test_i32_fill_rng
|
||||
|
|
|
@ -86,6 +86,8 @@ pooling_boiler = """
|
|||
func @main() -> i32 attributes {llvm.emit_c_interface} {
|
||||
%v0 = constant 0 : i32
|
||||
%v42 = constant 42.0 : f64
|
||||
%v77 = constant 77.0 : f64
|
||||
%v-13 = constant -13.0 : f64
|
||||
%v1 = constant 1.0 : f64
|
||||
|
||||
%input = memref.alloc() : memref<1x4x16x1xf64>
|
||||
|
@ -96,7 +98,11 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
|
|||
linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
|
||||
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
|
||||
memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
|
||||
memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>
|
||||
|
||||
call @pooling_on_buffers(%input, %shape, %output) :
|
||||
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
|
||||
|
@ -301,7 +307,7 @@ def test_conv_generic():
|
|||
test_conv_generic()
|
||||
|
||||
|
||||
def test_pooling_builtin():
|
||||
def test_max_pooling_builtin():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f64 = F64Type.get()
|
||||
|
@ -325,13 +331,14 @@ def test_pooling_builtin():
|
|||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# 77 is not selected due to the dilation 2 in the second dimension.
|
||||
# CHECK: RESULT: 42
|
||||
|
||||
|
||||
test_pooling_builtin()
|
||||
test_max_pooling_builtin()
|
||||
|
||||
|
||||
def test_pooling_generic():
|
||||
def test_max_pooling_generic():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f64 = F64Type.get()
|
||||
|
@ -360,7 +367,73 @@ def test_pooling_generic():
|
|||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# 77 is not selected due to the dilation 2 in the second dimension.
|
||||
# CHECK: RESULT: 42
|
||||
|
||||
|
||||
test_pooling_generic()
|
||||
test_max_pooling_generic()
|
||||
|
||||
|
||||
def test_min_pooling_builtin():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f64 = F64Type.get()
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
||||
MemRefType.get((1, 2, 4, 1), i32))
|
||||
def pooling_on_buffers(input, shape, output):
|
||||
linalg.pooling_nhwc_min_poly(
|
||||
input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||
|
||||
# TODO: FFI-based solution to allow testing and printing with python code.
|
||||
# Prepare arguments: one result i32.
|
||||
# Arguments must be passed as pointers.
|
||||
c_int_p = ctypes.c_int * 1
|
||||
res = c_int_p(-1)
|
||||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# CHECK: RESULT: -13
|
||||
|
||||
|
||||
test_min_pooling_builtin()
|
||||
|
||||
|
||||
def test_min_pooling_generic():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f64 = F64Type.get()
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
||||
MemRefType.get((1, 2, 4, 1), i32))
|
||||
def pooling_on_buffers(input, shape, output):
|
||||
linalg.pooling_nhwc_min_poly(
|
||||
input,
|
||||
shape,
|
||||
outs=[output],
|
||||
strides=[2, 4],
|
||||
dilations=[1, 2],
|
||||
emit_generic=True)
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||
|
||||
# TODO: FFI-based solution to allow testing and printing with python code.
|
||||
# Prepare arguments: one result i32.
|
||||
# Arguments must be passed as pointers.
|
||||
c_int_p = ctypes.c_int * 1
|
||||
res = c_int_p(-1)
|
||||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# CHECK: RESULT: -13
|
||||
|
||||
|
||||
test_min_pooling_generic()
|
||||
|
|
Loading…
Reference in New Issue