[mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect.

I don't see a point here in the lit tests here since sqrt, mul and other ops
expand as well. I just added "smoke" tests to verify that the conversion works
and does not create any illegal ops.

I will create a patch that adds a simple integration test to
mlir/test/Integration/Dialect/ComplexOps/ that will compare the values.

Differential Revision: https://reviews.llvm.org/D126539
This commit is contained in:
Alexander Belyaev 2022-05-27 15:59:19 +02:00
parent bcf3d52486
commit f5fa633b09
2 changed files with 158 additions and 2 deletions

View File

@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
}
};
// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto type = op.getType().cast<ComplexType>();
Type elementType = type.getElementType();
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
Value rhsSquaredPlusLhsSquared =
b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
Value sqrtOfRhsSquaredPlusLhsSquared =
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
Value i = b.create<complex::CreateOp>(type, zero, one);
Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
Value divResult =
b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
Value logResult = b.create<complex::LogOp>(divResult);
Value negativeOne = b.create<arith::ConstantOp>(
elementType, b.getFloatAttr(elementType, -1));
Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
return success();
}
};
template <typename ComparisonOp, arith::CmpFPredicate p>
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@ -700,6 +743,72 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
}
};
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto type = op.getType().cast<ComplexType>();
Type elementType = type.getElementType();
Value arg = adaptor.getComplex();
Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
Value absLhs = b.create<math::AbsOp>(real);
Value absArg = b.create<complex::AbsOp>(elementType, arg);
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
Value sqrtAddAbs = b.create<math::SqrtOp>(addAbs);
Value sqrtAddAbsDivTwo = b.create<arith::DivFOp>(
sqrtAddAbs, b.create<arith::ConstantOp>(
elementType, b.getFloatAttr(elementType, 2)));
Value realIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
Value imagIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
Value resultReal = sqrtAddAbsDivTwo;
Value imagDivTwoResultReal = b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultReal, resultReal));
Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
Value resultImag = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
resultReal),
imagDivTwoResultReal);
resultReal = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultImag, resultImag)),
resultReal);
Value realIsZero =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
Value imagIsZero =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
}
};
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
@ -735,6 +844,7 @@ void mlir::populateComplexToStandardConversionPatterns(
// clang-format off
patterns.add<
AbsOpConversion,
Atan2OpConversion,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@ -748,7 +858,8 @@ void mlir::populateComplexToStandardConversionPatterns(
MulOpConversion,
NegOpConversion,
SignOpConversion,
SinOpConversion>(patterns.getContext());
SinOpConversion,
SqrtOpConversion>(patterns.getContext());
// clang-format on
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@ -14,6 +14,17 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: return %[[NORM]] : f32
// -----
// CHECK-LABEL: func @complex_atan2
func.func @complex_atan2(%lhs: complex<f32>,
%rhs: complex<f32>) -> complex<f32> {
%atan2 = complex.atan2 %lhs, %rhs : complex<f32>
return %atan2 : complex<f32>
}
// -----
// CHECK-LABEL: func @complex_add
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@ -29,6 +40,8 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_cos
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@ -50,6 +63,8 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
// -----
// CHECK-LABEL: func @complex_div
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@ -159,6 +174,8 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_eq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@ -174,6 +191,8 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: return %[[EQUAL]] : i1
// -----
// CHECK-LABEL: func @complex_exp
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@ -190,6 +209,8 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func.func @complex_expm1(
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@ -211,6 +232,8 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RES]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_log
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@ -230,6 +253,8 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_log1p
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@ -254,6 +279,8 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_mul
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@ -372,6 +399,8 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_neg
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@ -385,6 +414,8 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_neq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@ -400,6 +431,8 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
// CHECK: return %[[NOT_EQUAL]] : i1
// -----
// CHECK-LABEL: func @complex_sin
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@ -421,6 +454,8 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
// -----
// CHECK-LABEL: func @complex_sign
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@ -445,6 +480,8 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_sub
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@ -459,3 +496,11 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @complex_sqrt
func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
return %sqrt : complex<f32>
}