forked from OSchip/llvm-project
[MLIR][Math] Enable constant folding of ops
Enable constant folding of ops within the math dialect, and introduce constant folders for ceil and log2 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D117085
This commit is contained in:
parent
aaa0c81683
commit
2f8b956ab6
|
@ -15,6 +15,7 @@ def Math_Dialect : Dialect {
|
|||
The math dialect is intended to hold mathematical operations on integer and
|
||||
floating type beyond simple arithmetics.
|
||||
}];
|
||||
let hasConstantMaterializer = 1;
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||
}
|
||||
#endif // MATH_BASE
|
||||
|
|
|
@ -195,6 +195,7 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
|
|||
%x = math.ceil %y : tensor<4x?xf8>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -649,6 +650,7 @@ def Math_Log2Op : Math_FloatUnaryOp<"log2"> {
|
|||
%y = math.log2 %x : f64
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -6,7 +6,9 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::math;
|
||||
|
@ -17,3 +19,58 @@ using namespace mlir::math;
|
|||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CeilOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<FloatAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
APFloat sourceVal = attr.getValue();
|
||||
sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
|
||||
|
||||
return FloatAttr::get(getType(), sourceVal);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Log2Op folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<FloatAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
auto FT = getType().cast<FloatType>();
|
||||
|
||||
APFloat APF = attr.getValue();
|
||||
|
||||
if (APF.isNegative())
|
||||
return {};
|
||||
|
||||
if (FT.getWidth() == 64)
|
||||
return FloatAttr::get(getType(), log2(APF.convertToDouble()));
|
||||
|
||||
if (FT.getWidth() == 32)
|
||||
return FloatAttr::get(getType(), log2f(APF.convertToDouble()));
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Materialize an integer or floating point constant.
|
||||
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<arith::ConstantOp>(loc, value, type);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
// RUN: mlir-opt %s -canonicalize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @ceil_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @ceil_fold() -> f32 {
|
||||
%c = arith.constant 0.3 : f32
|
||||
%r = math.ceil %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ceil_fold2
|
||||
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @ceil_fold2() -> f32 {
|
||||
%c = arith.constant 2.0 : f32
|
||||
%r = math.ceil %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @log2_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @log2_fold() -> f32 {
|
||||
%c = arith.constant 4.0 : f32
|
||||
%r = math.log2 %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @log2_fold2
|
||||
// CHECK: %[[cst:.+]] = arith.constant 0xFF800000 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @log2_fold2() -> f32 {
|
||||
%c = arith.constant 0.0 : f32
|
||||
%r = math.log2 %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @log2_nofold2
|
||||
// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f32
|
||||
// CHECK: %[[res:.+]] = math.log2 %[[cst]] : f32
|
||||
// CHECK: return %[[res]]
|
||||
func @log2_nofold2() -> f32 {
|
||||
%c = arith.constant -1.0 : f32
|
||||
%r = math.log2 %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @log2_fold_64
|
||||
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64
|
||||
// CHECK: return %[[cst]]
|
||||
func @log2_fold_64() -> f64 {
|
||||
%c = arith.constant 4.0 : f64
|
||||
%r = math.log2 %c : f64
|
||||
return %r : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @log2_fold2_64
|
||||
// CHECK: %[[cst:.+]] = arith.constant 0xFFF0000000000000 : f64
|
||||
// CHECK: return %[[cst]]
|
||||
func @log2_fold2_64() -> f64 {
|
||||
%c = arith.constant 0.0 : f64
|
||||
%r = math.log2 %c : f64
|
||||
return %r : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @log2_nofold2_64
|
||||
// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f64
|
||||
// CHECK: %[[res:.+]] = math.log2 %[[cst]] : f64
|
||||
// CHECK: return %[[res]]
|
||||
func @log2_nofold2_64() -> f64 {
|
||||
%c = arith.constant -1.0 : f64
|
||||
%r = math.log2 %c : f64
|
||||
return %r : f64
|
||||
}
|
Loading…
Reference in New Issue