Fix a bug in algebraic simplification, and enable the tests.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D107788
This commit is contained in:
bakhtiyar 2021-08-09 15:54:16 -07:00 committed by Eugene Zhulenev
parent a3290ea156
commit 391456f33c
2 changed files with 33 additions and 3 deletions

View File

@ -80,7 +80,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
return success();
}
// Replace `pow(x, 2.0)` with `x * x * x`.
// Replace `pow(x, 3.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x}));
rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
@ -95,12 +95,18 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
return success();
}
// Replace `pow(x, -2.0)` with `sqrt(x)`.
if (isExponentValue(-1.0)) {
// Replace `pow(x, 0.5)` with `sqrt(x)`.
if (isExponentValue(0.5)) {
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
return success();
}
// Replace `pow(x, -0.5)` with `rsqrt(x)`.
if (isExponentValue(-0.5)) {
rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
return success();
}
return failure();
}

View File

@ -49,3 +49,27 @@ func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}
// CHECK-LABEL: @pow_sqrt
func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: %[[SCALAR:.*]] = math.sqrt %arg0
// CHECK: %[[VECTOR:.*]] = math.sqrt %arg1
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = constant 0.5 : f32
%v = constant dense <0.5> : vector<4xf32>
%0 = math.powf %arg0, %c : f32
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}
// CHECK-LABEL: @pow_rsqrt
func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0
// CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = constant -0.5 : f32
%v = constant dense <-0.5> : vector<4xf32>
%0 = math.powf %arg0, %c : f32
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}