From 375a5cb648835db0b1eacfc921cbb04844b8b3b4 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 11 Aug 2022 15:56:07 +0200 Subject: [PATCH] Don't lower log1p(x) to log(1 + x). The latter has accuracy issues around 0. The lowering in MathToLLVM is kept for now. Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D131676 --- .../mlir/Conversion/MathToLibm/MathToLibm.h | 6 +++-- .../ComplexToStandard/ComplexToStandard.cpp | 21 ++++++++++++++-- mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 7 ++++-- .../convert-to-standard.mlir | 24 +++++++++---------- .../MathToLibm/convert-to-libm.mlir | 12 ++++++++++ 5 files changed, 52 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h index 9e7aa1a0f52a..c07dcfd090d2 100644 --- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h +++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h @@ -15,8 +15,10 @@ template class OperationPass; /// Populate the given list with patterns that convert from Math to Libm calls. -void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, - PatternBenefit benefit); +/// If log1pBenefit is present, use it instead of benefit for the Log1p op. +void populateMathToLibmConversionPatterns( + RewritePatternSet &patterns, PatternBenefit benefit, + llvm::Optional log1pBenefit = llvm::None); /// Create a pass to convert Math operations to libm calls. std::unique_ptr> createConvertMathToLibmPass(); diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 643806f2b0fa..064b0db08a41 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -513,11 +513,28 @@ struct Log1pOpConversion : public OpConversionPattern { Value real = b.create(elementType, adaptor.getComplex()); Value imag = b.create(elementType, adaptor.getComplex()); + + Value half = b.create(elementType, + b.getFloatAttr(elementType, 0.5)); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); + Value two = b.create(elementType, + b.getFloatAttr(elementType, 2)); + + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) + // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) + Value sumSq = b.create(real, real); + sumSq = b.create(sumSq, b.create(real, two)); + sumSq = b.create(sumSq, b.create(imag, imag)); + Value logSumSq = b.create(elementType, sumSq); + Value resultReal = b.create(logSumSq, half); + Value realPlusOne = b.create(real, one); - Value newComplex = b.create(type, realPlusOne, imag); - rewriter.replaceOpWithNewOp(op, type, newComplex); + + Value resultImag = b.create(elementType, imag, realPlusOne); + rewriter.replaceOpWithNewOp(op, type, resultReal, + resultImag); return success(); } }; diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index 2cd5ca08395a..43ce675da926 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -138,8 +138,9 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, return success(); } -void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, - PatternBenefit benefit) { +void mlir::populateMathToLibmConversionPatterns( + RewritePatternSet &patterns, PatternBenefit benefit, + llvm::Optional log1pBenefit) { patterns.add, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, @@ -168,6 +169,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, "cos", benefit); patterns.add>(patterns.getContext(), "sinf", "sin", benefit); + patterns.add>( + patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit)); } namespace { diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index cac758a89b61..e11187af14b8 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\ -// RUN: FileCheck %s +// RUN: FileCheck %s --dump-input=always // CHECK-LABEL: func @complex_abs // CHECK-SAME: %[[ARG:.*]]: complex @@ -262,21 +262,21 @@ func.func @complex_log1p(%arg: complex) -> complex { %log1p = complex.log1p %arg: complex return %log1p : complex } + // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32 +// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32 +// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32 +// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32 // CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32 -// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex -// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex -// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 -// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex -// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex -// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32 +// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir index ced15f571a40..f67c994a3b78 100644 --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -303,3 +303,15 @@ func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec %double_result = math.tan %double : vector<2xf64> return %float_result, %double_result : vector<2xf32>, vector<2xf64> } + +// CHECK-LABEL: func @log1p_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @log1p_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @log1pf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.log1p %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @log1p(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.log1p %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} \ No newline at end of file