diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index a676e0afe9a8..e1eca6181dff 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -413,6 +413,30 @@ struct ExpOpConversion : public OpConversionPattern { } }; +struct Expm1OpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = adaptor.getComplex().getType().cast(); + auto elementType = type.getElementType().cast(); + + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value exp = b.create(adaptor.getComplex()); + + Value real = b.create(elementType, exp); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); + Value realMinusOne = b.create(real, one); + Value imag = b.create(elementType, exp); + + rewriter.replaceOpWithNewOp(op, type, realMinusOne, + imag); + return success(); + } +}; + struct LogOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -718,6 +742,7 @@ void mlir::populateComplexToStandardConversionPatterns( CosOpConversion, DivOpConversion, ExpOpConversion, + Expm1OpConversion, LogOpConversion, Log1pOpConversion, MulOpConversion, diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 8e7098f832bb..6f57e722b520 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -190,6 +190,27 @@ func.func @complex_exp(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-LABEL: func.func @complex_expm1( +// CHECK-SAME: %[[ARG:.*]]: complex) -> complex { +func.func @complex_expm1(%arg: complex) -> complex { + %expm1 = complex.expm1 %arg: complex + return %expm1 : complex +} +// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32 +// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32 +// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32 +// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32 +// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32 +// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex +// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32 +// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex +// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex +// CHECK: return %[[RES]] : complex + // CHECK-LABEL: func @complex_log // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_log(%arg: complex) -> complex {