[mlir][linalg] Update OpDSL to use the newly introduced min and max ops.

Implement min and max using the newly introduced std operations instead of
relying on compare and select.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D111170
parent 24af1ba605
commit a744c7e962
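In IR terms the change is mechanical: where a generalized body previously lowered a max through a compare followed by a select, it now carries a single op. For the f32 case (operand names illustrative, taken from the test updates below):

    %cond = cmpf ogt, %out, %in : f32
    %max = select %cond, %out, %in : f32

becomes

    %max = maxf %out, %in : f32

The integer paths switch to maxsi/minsi, and the float min case to minf, in the same way.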
@@ -276,18 +276,20 @@ public:
   }
 
   Value applyfn__max(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+      return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
     if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+      return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
   Value applyfn__min(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+      return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
     if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
+      return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
@@ -324,17 +326,6 @@ private:
   MLIRContext *context;
   Block &block;
-
-  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
 
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 
@@ -319,20 +319,16 @@ class _BodyBuilder:
 
   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
+      return std.MaxFOp(lhs.type, lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
+      return std.MaxSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError(f"Unsupported 'max' operand: {lhs}")
 
   def _eval_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+      return std.MinFOp(lhs.type, lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+      return std.MinSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError(f"Unsupported 'min' operand: {lhs}")
 
@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
   if BF16Type.isinstance(t):
     return 16
   raise NotImplementedError(f"Unhandled floating point type switch {t}")
-
-
-def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
-
-
-def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
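For context, here is a minimal self-contained sketch of what the new float path emits, written against the same-era python bindings used above (the function name fmax and the module scaffolding are illustrative, not part of this patch):

    # Sketch: build a single std.maxf, as the updated _eval_max now does for
    # floating-point operands, instead of the removed cmpf-and-select helper.
    from mlir.ir import Context, InsertionPoint, Location, Module, F32Type
    from mlir.dialects import builtin, std

    with Context(), Location.unknown():
      module = Module.create()
      with InsertionPoint(module.body):
        f32 = F32Type.get()

        @builtin.FuncOp.from_py_func(f32, f32)
        def fmax(lhs, rhs):
          # One op where the emitter previously built CmpFOp(ogt) + SelectOp.
          return std.MaxFOp(lhs.type, lhs, rhs).result
      print(module)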
@@ -38,8 +38,7 @@ func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_f32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
 // CHECK-NEXT: linalg.yield %[[MAX]] : f32
 // CHECK-NEXT: -> tensor<1x2x4x1xf32>
 
@@ -53,8 +52,7 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_i32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
 // CHECK-NEXT: linalg.yield %[[MAX]] : i32
 // CHECK-NEXT: -> tensor<1x2x4x1xi32>
 
@@ -68,9 +66,8 @@ func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_f32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK-NEXT: %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[MIN]] : f32
 // CHECK-NEXT: -> tensor<1x2x4x1xf32>
 
 // -----
@@ -83,9 +80,8 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_i32
 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT: %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: linalg.yield %[[MAX]] : i32
+// CHECK-NEXT: %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[MIN]] : i32
 // CHECK-NEXT: -> tensor<1x2x4x1xi32>
 
 // -----
@@ -242,8 +242,7 @@ with Context() as ctx, Location.unknown():
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
   # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
   # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
-  # CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
-  # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
+  # CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
   # CHECK-NEXT: linalg.yield %[[MAX]] : i32
   # CHECK-NEXT: -> tensor<2x4xi32>
   @builtin.FuncOp.from_py_func(
@@ -258,8 +257,7 @@ with Context() as ctx, Location.unknown():
   # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
   # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
-  # CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
-  # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
+  # CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32
   # CHECK-NEXT: linalg.yield %[[MAX]] : f32
   # CHECK-NEXT: -> tensor<2x4xf32>
   @builtin.FuncOp.from_py_func(
@@ -270,7 +268,7 @@ with Context() as ctx, Location.unknown():
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
   # CHECK-LABEL: @test_f32i32_min_pooling
-  # CHECK: = cmpi slt,
+  # CHECK: = minsi
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
       RankedTensorType.get((2, 4), i32))
@@ -279,7 +277,7 @@ with Context() as ctx, Location.unknown():
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
   # CHECK-LABEL: @test_f32f32_min_pooling
-  # CHECK: = cmpf olt,
+  # CHECK: = minf
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
       RankedTensorType.get((2, 4), f32))
@@ -118,7 +118,7 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
 
 def transform(module, boilerplate):
   import mlir.conversions
-  import mlir.dialects.linalg.passes
+  import mlir.all_passes_registration
   import mlir.transforms
 
   # TODO: Allow cloning functions from one module to another.
@@ -128,8 +128,8 @@ def transform(module, boilerplate):
       boilerplate)
   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-scf-to-std), convert-vector-to-llvm," +
-      "convert-memref-to-llvm,convert-std-to-llvm," +
+      "convert-scf-to-std, std-expand), convert-vector-to-llvm," +
+      "convert-memref-to-llvm, convert-std-to-llvm," +
       "reconcile-unrealized-casts")
   pm.run(mod)
   return mod
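Note on the pipeline change: at this point the new max/min ops have no direct conversion to LLVM, which is presumably why std-expand joins the pipeline; it rewrites them into compare-and-select form that convert-std-to-llvm can handle. A minimal sketch of that step in isolation (the parsed function is illustrative):

    # Sketch: run just std-expand over a module containing the new maxf op.
    import mlir.all_passes_registration
    from mlir.ir import Context, Module
    from mlir.passmanager import PassManager

    with Context():
      mod = Module.parse("""
          func @f(%a: f32, %b: f32) -> f32 {
            %0 = maxf %a, %b : f32
            return %0 : f32
          }""")
      PassManager.parse("builtin.func(std-expand)").run(mod)
      print(mod)  # the max is now expressed with cmpf + select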