forked from OSchip/llvm-project
[mlir][Math] Add more constant folder for Math ops.
This revision add constant folder for abs, copysign, ctlz, cttz and ctpop. Differential Revision: https://reviews.llvm.org/D122115
This commit is contained in:
parent
8e64d84995
commit
e609417cdc
|
@ -89,6 +89,7 @@ def Math_AbsOp : Math_FloatUnaryOp<"abs"> {
|
|||
%x = math.abs %y : tensor<4x?xf8>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -230,6 +231,7 @@ def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> {
|
|||
%x = math.copysign %y, %z : tensor<4x?xf8>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -320,6 +322,7 @@ def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
|
|||
%x = math.ctlz %y : tensor<4x?xi8>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -344,6 +347,7 @@ def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
|
|||
%x = math.cttz %y : tensor<4x?xi8>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -368,6 +372,7 @@ def Math_CtPopOp : Math_IntegerUnaryOp<"ctpop"> {
|
|||
%x = math.ctpop %y : tensor<4x?xi8>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -20,6 +20,32 @@ using namespace mlir::math;
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AbsOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<FloatAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
auto ft = getType().cast<FloatType>();
|
||||
|
||||
APFloat apf = attr.getValue();
|
||||
|
||||
if (ft.getWidth() == 64)
|
||||
return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
|
||||
|
||||
if (ft.getWidth() == 32)
|
||||
return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CeilOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -39,6 +65,81 @@ OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
|
|||
return FloatAttr::get(getType(), sourceVal);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CopySignOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto ft = getType().dyn_cast<FloatType>();
|
||||
if (!ft)
|
||||
return {};
|
||||
|
||||
APFloat vals[2]{APFloat(ft.getFloatSemantics()),
|
||||
APFloat(ft.getFloatSemantics())};
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
if (!operands[i])
|
||||
return {};
|
||||
|
||||
auto attr = operands[i].dyn_cast<FloatAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
vals[i] = attr.getValue();
|
||||
}
|
||||
|
||||
vals[0].copySign(vals[1]);
|
||||
|
||||
return FloatAttr::get(getType(), vals[0]);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CountLeadingZerosOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<IntegerAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CountTrailingZerosOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<IntegerAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CtPopOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<IntegerAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
return IntegerAttr::get(getType(), attr.getValue().countPopulation());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Log2Op folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -91,3 +91,58 @@ func @sqrt_fold() -> f32 {
|
|||
%r = math.sqrt %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @abs_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @abs_fold() -> f32 {
|
||||
%c = arith.constant -4.0 : f32
|
||||
%r = math.abs %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @copysign_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @copysign_fold() -> f32 {
|
||||
%c1 = arith.constant 4.0 : f32
|
||||
%c2 = arith.constant -9.0 : f32
|
||||
%r = math.copysign %c1, %c2 : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctlz_fold1
|
||||
// CHECK: %[[cst:.+]] = arith.constant 31 : i32
|
||||
// CHECK: return %[[cst]]
|
||||
func @ctlz_fold1() -> i32 {
|
||||
%c = arith.constant 1 : i32
|
||||
%r = math.ctlz %c : i32
|
||||
return %r : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctlz_fold2
|
||||
// CHECK: %[[cst:.+]] = arith.constant 7 : i8
|
||||
// CHECK: return %[[cst]]
|
||||
func @ctlz_fold2() -> i8 {
|
||||
%c = arith.constant 1 : i8
|
||||
%r = math.ctlz %c : i8
|
||||
return %r : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cttz_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 8 : i32
|
||||
// CHECK: return %[[cst]]
|
||||
func @cttz_fold() -> i32 {
|
||||
%c = arith.constant 256 : i32
|
||||
%r = math.cttz %c : i32
|
||||
return %r : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctpop_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 16 : i32
|
||||
// CHECK: return %[[cst]]
|
||||
func @ctpop_fold() -> i32 {
|
||||
%c = arith.constant 0xFF0000FF : i32
|
||||
%r = math.ctpop %c : i32
|
||||
return %r : i32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue