[mlir][Math] Add more constant folder for Math ops.

This revision add constant folder for abs, copysign, ctlz, cttz and
ctpop.

Differential Revision: https://reviews.llvm.org/D122115
This commit is contained in:
jacquesguan 2022-03-21 15:43:40 +08:00
parent 8e64d84995
commit e609417cdc
3 changed files with 161 additions and 0 deletions

View File

@ -89,6 +89,7 @@ def Math_AbsOp : Math_FloatUnaryOp<"abs"> {
%x = math.abs %y : tensor<4x?xf8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@ -230,6 +231,7 @@ def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> {
%x = math.copysign %y, %z : tensor<4x?xf8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@ -320,6 +322,7 @@ def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
%x = math.ctlz %y : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@ -344,6 +347,7 @@ def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
%x = math.cttz %y : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@ -368,6 +372,7 @@ def Math_CtPopOp : Math_IntegerUnaryOp<"ctpop"> {
%x = math.ctpop %y : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -20,6 +20,32 @@ using namespace mlir::math;
#define GET_OP_CLASSES
#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
//===----------------------------------------------------------------------===//
// AbsOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = operands.front();
if (!constOperand)
return {};
auto attr = constOperand.dyn_cast<FloatAttr>();
if (!attr)
return {};
auto ft = getType().cast<FloatType>();
APFloat apf = attr.getValue();
if (ft.getWidth() == 64)
return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
if (ft.getWidth() == 32)
return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
return {};
}
//===----------------------------------------------------------------------===//
// CeilOp folder
//===----------------------------------------------------------------------===//
@ -39,6 +65,81 @@ OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
return FloatAttr::get(getType(), sourceVal);
}
//===----------------------------------------------------------------------===//
// CopySignOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
auto ft = getType().dyn_cast<FloatType>();
if (!ft)
return {};
APFloat vals[2]{APFloat(ft.getFloatSemantics()),
APFloat(ft.getFloatSemantics())};
for (int i = 0; i < 2; ++i) {
if (!operands[i])
return {};
auto attr = operands[i].dyn_cast<FloatAttr>();
if (!attr)
return {};
vals[i] = attr.getValue();
}
vals[0].copySign(vals[1]);
return FloatAttr::get(getType(), vals[0]);
}
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = operands.front();
if (!constOperand)
return {};
auto attr = constOperand.dyn_cast<IntegerAttr>();
if (!attr)
return {};
return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
}
//===----------------------------------------------------------------------===//
// CountTrailingZerosOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = operands.front();
if (!constOperand)
return {};
auto attr = constOperand.dyn_cast<IntegerAttr>();
if (!attr)
return {};
return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
}
//===----------------------------------------------------------------------===//
// CtPopOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = operands.front();
if (!constOperand)
return {};
auto attr = constOperand.dyn_cast<IntegerAttr>();
if (!attr)
return {};
return IntegerAttr::get(getType(), attr.getValue().countPopulation());
}
//===----------------------------------------------------------------------===//
// Log2Op folder
//===----------------------------------------------------------------------===//

View File

@ -91,3 +91,58 @@ func @sqrt_fold() -> f32 {
%r = math.sqrt %c : f32
return %r : f32
}
// CHECK-LABEL: @abs_fold
// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
// CHECK: return %[[cst]]
func @abs_fold() -> f32 {
%c = arith.constant -4.0 : f32
%r = math.abs %c : f32
return %r : f32
}
// CHECK-LABEL: @copysign_fold
// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32
// CHECK: return %[[cst]]
func @copysign_fold() -> f32 {
%c1 = arith.constant 4.0 : f32
%c2 = arith.constant -9.0 : f32
%r = math.copysign %c1, %c2 : f32
return %r : f32
}
// CHECK-LABEL: @ctlz_fold1
// CHECK: %[[cst:.+]] = arith.constant 31 : i32
// CHECK: return %[[cst]]
func @ctlz_fold1() -> i32 {
%c = arith.constant 1 : i32
%r = math.ctlz %c : i32
return %r : i32
}
// CHECK-LABEL: @ctlz_fold2
// CHECK: %[[cst:.+]] = arith.constant 7 : i8
// CHECK: return %[[cst]]
func @ctlz_fold2() -> i8 {
%c = arith.constant 1 : i8
%r = math.ctlz %c : i8
return %r : i8
}
// CHECK-LABEL: @cttz_fold
// CHECK: %[[cst:.+]] = arith.constant 8 : i32
// CHECK: return %[[cst]]
func @cttz_fold() -> i32 {
%c = arith.constant 256 : i32
%r = math.cttz %c : i32
return %r : i32
}
// CHECK-LABEL: @ctpop_fold
// CHECK: %[[cst:.+]] = arith.constant 16 : i32
// CHECK: return %[[cst]]
func @ctpop_fold() -> i32 {
%c = arith.constant 0xFF0000FF : i32
%r = math.ctpop %c : i32
return %r : i32
}