From e609417cdc934c6101ca512b00edcf47d9aa4211 Mon Sep 17 00:00:00 2001 From: jacquesguan Date: Mon, 21 Mar 2022 15:43:40 +0800 Subject: [PATCH] [mlir][Math] Add more constant folder for Math ops. This revision add constant folder for abs, copysign, ctlz, cttz and ctpop. Differential Revision: https://reviews.llvm.org/D122115 --- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 5 + mlir/lib/Dialect/Math/IR/MathOps.cpp | 101 +++++++++++++++++++ mlir/test/Dialect/Math/canonicalize.mlir | 55 ++++++++++ 3 files changed, 161 insertions(+) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index b0ccbc21439b..221af3f6a5f2 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -89,6 +89,7 @@ def Math_AbsOp : Math_FloatUnaryOp<"abs"> { %x = math.abs %y : tensor<4x?xf8> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -230,6 +231,7 @@ def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> { %x = math.copysign %y, %z : tensor<4x?xf8> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -320,6 +322,7 @@ def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> { %x = math.ctlz %y : tensor<4x?xi8> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -344,6 +347,7 @@ def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> { %x = math.cttz %y : tensor<4x?xi8> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -368,6 +372,7 @@ def Math_CtPopOp : Math_IntegerUnaryOp<"ctpop"> { %x = math.ctpop %y : tensor<4x?xi8> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 42f8334403c1..28f42f814f6d 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -20,6 +20,32 @@ using namespace mlir::math; #define GET_OP_CLASSES #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" +//===----------------------------------------------------------------------===// +// AbsOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::AbsOp::fold(ArrayRef operands) { + auto constOperand = operands.front(); + if (!constOperand) + return {}; + + auto attr = constOperand.dyn_cast(); + if (!attr) + return {}; + + auto ft = getType().cast(); + + APFloat apf = attr.getValue(); + + if (ft.getWidth() == 64) + return FloatAttr::get(getType(), fabs(apf.convertToDouble())); + + if (ft.getWidth() == 32) + return FloatAttr::get(getType(), fabsf(apf.convertToFloat())); + + return {}; +} + //===----------------------------------------------------------------------===// // CeilOp folder //===----------------------------------------------------------------------===// @@ -39,6 +65,81 @@ OpFoldResult math::CeilOp::fold(ArrayRef operands) { return FloatAttr::get(getType(), sourceVal); } +//===----------------------------------------------------------------------===// +// CopySignOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CopySignOp::fold(ArrayRef operands) { + auto ft = getType().dyn_cast(); + if (!ft) + return {}; + + APFloat vals[2]{APFloat(ft.getFloatSemantics()), + APFloat(ft.getFloatSemantics())}; + for (int i = 0; i < 2; ++i) { + if (!operands[i]) + return {}; + + auto attr = operands[i].dyn_cast(); + if (!attr) + return {}; + + vals[i] = attr.getValue(); + } + + vals[0].copySign(vals[1]); + + return FloatAttr::get(getType(), vals[0]); +} + +//===----------------------------------------------------------------------===// +// CountLeadingZerosOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef operands) { + auto constOperand = operands.front(); + if (!constOperand) + return {}; + + auto attr = constOperand.dyn_cast(); + if (!attr) + return {}; + + return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros()); +} + +//===----------------------------------------------------------------------===// +// CountTrailingZerosOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef operands) { + auto constOperand = operands.front(); + if (!constOperand) + return {}; + + auto attr = constOperand.dyn_cast(); + if (!attr) + return {}; + + return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros()); +} + +//===----------------------------------------------------------------------===// +// CtPopOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CtPopOp::fold(ArrayRef operands) { + auto constOperand = operands.front(); + if (!constOperand) + return {}; + + auto attr = constOperand.dyn_cast(); + if (!attr) + return {}; + + return IntegerAttr::get(getType(), attr.getValue().countPopulation()); +} + //===----------------------------------------------------------------------===// // Log2Op folder //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir index 27a92908eaec..45b13b455a2f 100644 --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -91,3 +91,58 @@ func @sqrt_fold() -> f32 { %r = math.sqrt %c : f32 return %r : f32 } + +// CHECK-LABEL: @abs_fold +// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32 +// CHECK: return %[[cst]] +func @abs_fold() -> f32 { + %c = arith.constant -4.0 : f32 + %r = math.abs %c : f32 + return %r : f32 +} + +// CHECK-LABEL: @copysign_fold +// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32 +// CHECK: return %[[cst]] +func @copysign_fold() -> f32 { + %c1 = arith.constant 4.0 : f32 + %c2 = arith.constant -9.0 : f32 + %r = math.copysign %c1, %c2 : f32 + return %r : f32 +} + +// CHECK-LABEL: @ctlz_fold1 +// CHECK: %[[cst:.+]] = arith.constant 31 : i32 +// CHECK: return %[[cst]] +func @ctlz_fold1() -> i32 { + %c = arith.constant 1 : i32 + %r = math.ctlz %c : i32 + return %r : i32 +} + +// CHECK-LABEL: @ctlz_fold2 +// CHECK: %[[cst:.+]] = arith.constant 7 : i8 +// CHECK: return %[[cst]] +func @ctlz_fold2() -> i8 { + %c = arith.constant 1 : i8 + %r = math.ctlz %c : i8 + return %r : i8 +} + +// CHECK-LABEL: @cttz_fold +// CHECK: %[[cst:.+]] = arith.constant 8 : i32 +// CHECK: return %[[cst]] +func @cttz_fold() -> i32 { + %c = arith.constant 256 : i32 + %r = math.cttz %c : i32 + return %r : i32 +} + +// CHECK-LABEL: @ctpop_fold +// CHECK: %[[cst:.+]] = arith.constant 16 : i32 +// CHECK: return %[[cst]] +func @ctpop_fold() -> i32 { + %c = arith.constant 0xFF0000FF : i32 + %r = math.ctpop %c : i32 + return %r : i32 +}