forked from OSchip/llvm-project
Fix a bug in algebraic simplification, and enable the tests.
Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D107788
This commit is contained in:
parent
a3290ea156
commit
391456f33c
|
@ -80,7 +80,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, 2.0)` with `x * x * x`.
|
||||
// Replace `pow(x, 3.0)` with `x * x * x`.
|
||||
if (isExponentValue(3.0)) {
|
||||
Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x}));
|
||||
rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
|
||||
|
@ -95,12 +95,18 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, -2.0)` with `sqrt(x)`.
|
||||
if (isExponentValue(-1.0)) {
|
||||
// Replace `pow(x, 0.5)` with `sqrt(x)`.
|
||||
if (isExponentValue(0.5)) {
|
||||
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, -0.5)` with `rsqrt(x)`.
|
||||
if (isExponentValue(-0.5)) {
|
||||
rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
|
@ -49,3 +49,27 @@ func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
|||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pow_sqrt
|
||||
func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
||||
// CHECK: %[[SCALAR:.*]] = math.sqrt %arg0
|
||||
// CHECK: %[[VECTOR:.*]] = math.sqrt %arg1
|
||||
// CHECK: return %[[SCALAR]], %[[VECTOR]]
|
||||
%c = constant 0.5 : f32
|
||||
%v = constant dense <0.5> : vector<4xf32>
|
||||
%0 = math.powf %arg0, %c : f32
|
||||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pow_rsqrt
|
||||
func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
||||
// CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0
|
||||
// CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1
|
||||
// CHECK: return %[[SCALAR]], %[[VECTOR]]
|
||||
%c = constant -0.5 : f32
|
||||
%v = constant dense <-0.5> : vector<4xf32>
|
||||
%0 = math.powf %arg0, %c : f32
|
||||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue