forked from OSchip/llvm-project
[mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect.
I don't see a point here in the lit tests here since sqrt, mul and other ops expand as well. I just added "smoke" tests to verify that the conversion works and does not create any illegal ops. I will create a patch that adds a simple integration test to mlir/test/Integration/Dialect/ComplexOps/ that will compare the values. Differential Revision: https://reviews.llvm.org/D126539
This commit is contained in:
parent
bcf3d52486
commit
f5fa633b09
|
@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
|
||||
struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
|
||||
using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
|
||||
auto type = op.getType().cast<ComplexType>();
|
||||
Type elementType = type.getElementType();
|
||||
|
||||
Value lhs = adaptor.getLhs();
|
||||
Value rhs = adaptor.getRhs();
|
||||
|
||||
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
|
||||
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
|
||||
Value rhsSquaredPlusLhsSquared =
|
||||
b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
|
||||
Value sqrtOfRhsSquaredPlusLhsSquared =
|
||||
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
|
||||
|
||||
Value zero =
|
||||
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
|
||||
Value one = b.create<arith::ConstantOp>(elementType,
|
||||
b.getFloatAttr(elementType, 1));
|
||||
Value i = b.create<complex::CreateOp>(type, zero, one);
|
||||
Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
|
||||
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
|
||||
|
||||
Value divResult =
|
||||
b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
|
||||
Value logResult = b.create<complex::LogOp>(divResult);
|
||||
|
||||
Value negativeOne = b.create<arith::ConstantOp>(
|
||||
elementType, b.getFloatAttr(elementType, -1));
|
||||
Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
|
||||
|
||||
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ComparisonOp, arith::CmpFPredicate p>
|
||||
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
|
||||
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
|
||||
|
@ -700,6 +743,72 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
|
||||
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
|
||||
using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
|
||||
auto type = op.getType().cast<ComplexType>();
|
||||
Type elementType = type.getElementType();
|
||||
Value arg = adaptor.getComplex();
|
||||
|
||||
Value zero =
|
||||
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
|
||||
|
||||
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
|
||||
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
|
||||
|
||||
Value absLhs = b.create<math::AbsOp>(real);
|
||||
Value absArg = b.create<complex::AbsOp>(elementType, arg);
|
||||
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
|
||||
Value sqrtAddAbs = b.create<math::SqrtOp>(addAbs);
|
||||
Value sqrtAddAbsDivTwo = b.create<arith::DivFOp>(
|
||||
sqrtAddAbs, b.create<arith::ConstantOp>(
|
||||
elementType, b.getFloatAttr(elementType, 2)));
|
||||
|
||||
Value realIsNegative =
|
||||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
|
||||
Value imagIsNegative =
|
||||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
|
||||
|
||||
Value resultReal = sqrtAddAbsDivTwo;
|
||||
|
||||
Value imagDivTwoResultReal = b.create<arith::DivFOp>(
|
||||
imag, b.create<arith::AddFOp>(resultReal, resultReal));
|
||||
|
||||
Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
|
||||
|
||||
Value resultImag = b.create<arith::SelectOp>(
|
||||
realIsNegative,
|
||||
b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
|
||||
resultReal),
|
||||
imagDivTwoResultReal);
|
||||
|
||||
resultReal = b.create<arith::SelectOp>(
|
||||
realIsNegative,
|
||||
b.create<arith::DivFOp>(
|
||||
imag, b.create<arith::AddFOp>(resultImag, resultImag)),
|
||||
resultReal);
|
||||
|
||||
Value realIsZero =
|
||||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
|
||||
Value imagIsZero =
|
||||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
|
||||
Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
|
||||
|
||||
resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
|
||||
resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
|
||||
|
||||
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
||||
resultImag);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
|
||||
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
|
||||
|
||||
|
@ -735,6 +844,7 @@ void mlir::populateComplexToStandardConversionPatterns(
|
|||
// clang-format off
|
||||
patterns.add<
|
||||
AbsOpConversion,
|
||||
Atan2OpConversion,
|
||||
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
|
||||
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
|
||||
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
|
||||
|
@ -748,7 +858,8 @@ void mlir::populateComplexToStandardConversionPatterns(
|
|||
MulOpConversion,
|
||||
NegOpConversion,
|
||||
SignOpConversion,
|
||||
SinOpConversion>(patterns.getContext());
|
||||
SinOpConversion,
|
||||
SqrtOpConversion>(patterns.getContext());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
|
||||
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @complex_abs
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
|
@ -14,6 +14,17 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
|
|||
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
|
||||
// CHECK: return %[[NORM]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_atan2
|
||||
func.func @complex_atan2(%lhs: complex<f32>,
|
||||
%rhs: complex<f32>) -> complex<f32> {
|
||||
%atan2 = complex.atan2 %lhs, %rhs : complex<f32>
|
||||
return %atan2 : complex<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_add
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
|
@ -29,6 +40,8 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_cos
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -50,6 +63,8 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_div
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
|
@ -159,6 +174,8 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_eq
|
||||
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
|
||||
func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
|
||||
|
@ -174,6 +191,8 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
|
|||
// CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
|
||||
// CHECK: return %[[EQUAL]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_exp
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -190,6 +209,8 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @complex_expm1(
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
|
||||
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -211,6 +232,8 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RES]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_log
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -230,6 +253,8 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_log1p
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -254,6 +279,8 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_mul
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
|
@ -372,6 +399,8 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_neg
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -385,6 +414,8 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_neq
|
||||
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
|
||||
func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
|
||||
|
@ -400,6 +431,8 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
|
|||
// CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
|
||||
// CHECK: return %[[NOT_EQUAL]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_sin
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -421,6 +454,8 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_sign
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
|
||||
|
@ -445,6 +480,8 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_sub
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
|
@ -459,3 +496,11 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
|
||||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_sqrt
|
||||
func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
|
||||
%sqrt = complex.sqrt %arg : complex<f32>
|
||||
return %sqrt : complex<f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue