forked from OSchip/llvm-project
[mlir][Math] Add constant folder for sqrt.
Differential Revision: https://reviews.llvm.org/D121980
This commit is contained in:
parent
14bd14f9f9
commit
26c95ae389
|
@ -724,6 +724,7 @@ def Math_SqrtOp : Math_FloatUnaryOp<"sqrt"> {
|
|||
%x = math.sqrt %y : tensor<4x?xf32>
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -101,6 +101,31 @@ OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto attr = constOperand.dyn_cast<FloatAttr>();
|
||||
if (!attr)
|
||||
return {};
|
||||
|
||||
auto ft = getType().cast<FloatType>();
|
||||
|
||||
APFloat apf = attr.getValue();
|
||||
|
||||
if (apf.isNegative())
|
||||
return {};
|
||||
|
||||
if (ft.getWidth() == 64)
|
||||
return FloatAttr::get(getType(), sqrt(apf.convertToDouble()));
|
||||
|
||||
if (ft.getWidth() == 32)
|
||||
return FloatAttr::get(getType(), sqrtf(apf.convertToFloat()));
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Materialize an integer or floating point constant.
|
||||
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
|
|
|
@ -82,3 +82,12 @@ func @powf_fold() -> f32 {
|
|||
%r = math.powf %c, %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sqrt_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func @sqrt_fold() -> f32 {
|
||||
%c = arith.constant 4.0 : f32
|
||||
%r = math.sqrt %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue