From a744c7e962d85a6c0b2de19eff840755ef5c2a1d Mon Sep 17 00:00:00 2001
From: Tobias Gysi <gysit@google.com>
Date: Wed, 6 Oct 2021 06:45:42 +0000
Subject: [PATCH] [mlir][linalg] Update OpDSL to use the newly introduced
 min and max ops.

Implement min and max using the newly introduced std operations instead
of relying on compare and select.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D111170
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 21 +++++-------------
 .../dialects/linalg/opdsl/lang/emitter.py     | 22 ++++---------------
 .../generalize-named-polymorphic-ops.mlir     | 16 +++++---------
 .../linalg/opdsl/emit_structured_generic.py   | 10 ++++-----
 .../integration/dialects/linalg/opsrun.py     |  6 ++---
 5 files changed, 23 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index af292878d1f6..69cd9e25e5d9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -276,18 +276,20 @@ public:
   }
 
   Value applyfn__max(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+      return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
     if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+      return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
   Value applyfn__min(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+      return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
     if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
+      return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
@@ -324,17 +326,6 @@ private:
   MLIRContext *context;
   Block &block;
 
-  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index b151a9ba9f39..4a883e79037b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -319,20 +319,16 @@
 
   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
+      return std.MaxFOp(lhs.type, lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
+      return std.MaxSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
   def _eval_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+      return std.MinFOp(lhs.type, lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+      return std.MinSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'min' operand: {lhs}")
 
@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
   if BF16Type.isinstance(t):
     return 16
   raise NotImplementedError(f"Unhandled floating point type switch {t}")
-
-
-def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
-
-
-def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 3e934d42012c..89fd83e585ee 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -38,8 +38,7 @@ func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_f32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
 // CHECK-NEXT:   linalg.yield %[[MAX]] : f32
 // CHECK-NEXT: -> tensor<1x2x4x1xf32>
 
@@ -53,8 +52,7 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_i32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
 // CHECK-NEXT:   linalg.yield %[[MAX]] : i32
 // CHECK-NEXT: -> tensor<1x2x4x1xi32>
 
@@ -68,9 +66,8 @@ func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_f32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT:   linalg.yield %[[MAX]] : f32
+// CHECK-NEXT:   %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
 // CHECK-NEXT: -> tensor<1x2x4x1xf32>
 
 // -----
@@ -83,9 +80,8 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_i32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   linalg.yield %[[MAX]] : i32
+// CHECK-NEXT:   %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
 // CHECK-NEXT: -> tensor<1x2x4x1xi32>
 
 // -----
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index ed3364485901..16a82f63dbc8 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -242,8 +242,7 @@ with Context() as ctx, Location.unknown():
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
   # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
   # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
-  # CHECK-NEXT:   %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
-  # CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
+  # CHECK-NEXT:   %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
   # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
   # CHECK-NEXT: -> tensor<2x4xi32>
   @builtin.FuncOp.from_py_func(
@@ -258,8 +257,7 @@ with Context() as ctx, Location.unknown():
   # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
   # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
-  # CHECK-NEXT:   %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
-  # CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
+  # CHECK-NEXT:   %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32
   # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
   # CHECK-NEXT: -> tensor<2x4xf32>
   @builtin.FuncOp.from_py_func(
@@ -270,7 +268,7 @@ with Context() as ctx, Location.unknown():
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
   # CHECK-LABEL: @test_f32i32_min_pooling
-  # CHECK: = cmpi slt,
+  # CHECK: = minsi
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
       RankedTensorType.get((2, 4), i32))
@@ -279,7 +277,7 @@ with Context() as ctx, Location.unknown():
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
   # CHECK-LABEL: @test_f32f32_min_pooling
-  # CHECK: = cmpf olt,
+  # CHECK: = minf
   @builtin.FuncOp.from_py_func(
      RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((2, 4), f32))
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 04ee6c8dc5ee..5491193fa992 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -118,7 +118,7 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
 
 def transform(module, boilerplate):
   import mlir.conversions
-  import mlir.dialects.linalg.passes
+  import mlir.all_passes_registration
   import mlir.transforms
 
   # TODO: Allow cloning functions from one module to another.
@@ -128,8 +128,8 @@ def transform(module, boilerplate):
       boilerplate)
   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-scf-to-std), convert-vector-to-llvm," +
-      "convert-memref-to-llvm,convert-std-to-llvm," +
+      "convert-scf-to-std, std-expand), convert-vector-to-llvm," +
+      "convert-memref-to-llvm, convert-std-to-llvm," +
       "reconcile-unrealized-casts")
   pm.run(mod)
   return mod
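
A quick before/after sketch of the generalized pooling body, distilled from
the FileCheck patterns above (illustrative only, not part of the patch; the
SSA names %cond, %max, %out, and %in are made up for the example). Where the
emitted region previously compared and selected:

    %cond = cmpf ogt, %out, %in : f32
    %max = select %cond, %out, %in : f32

it now yields a single dedicated std op:

    %max = maxf %out, %in : f32

Note that the integration test pipeline also gains the std-expand pass,
presumably because the new min and max ops are rewritten back into
compare-and-select form there, which convert-std-to-llvm can then lower.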