Don't lower log1p(x) to log(1 + x).

The latter has accuracy issues around 0. The lowering in MathToLLVM is kept for now.

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D131676
This commit is contained in:
Johannes Reifferscheid 2022-08-11 15:56:07 +02:00
parent 556efdba85
commit 375a5cb648
5 changed files with 52 additions and 18 deletions

View File

@ -15,8 +15,10 @@ template <typename T>
class OperationPass;
/// Populate the given list with patterns that convert from Math to Libm calls.
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit);
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
void populateMathToLibmConversionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
llvm::Optional<PatternBenefit> log1pBenefit = llvm::None);
/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

View File

@ -513,11 +513,28 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
Value half = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 0.5));
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
Value two = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 2));
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
Value sumSq = b.create<arith::MulFOp>(real, real);
sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
Value realPlusOne = b.create<arith::AddFOp>(real, one);
Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
}
};

View File

@ -138,8 +138,9 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
void mlir::populateMathToLibmConversionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
llvm::Optional<PatternBenefit> log1pBenefit) {
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
@ -168,6 +169,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
"cos", benefit);
patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
"sin", benefit);
patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
}
namespace {

View File

@ -1,5 +1,5 @@
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
// RUN: FileCheck %s
// RUN: FileCheck %s --dump-input=always
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@ -262,21 +262,21 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
%log1p = complex.log1p %arg: complex<f32>
return %log1p : complex<f32>
}
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex<f32>
// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>

View File

@ -303,3 +303,15 @@ func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
%double_result = math.tan %double : vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
// CHECK-LABEL: func @log1p_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func.func @log1p_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @log1pf(%[[FLOAT]]) : (f32) -> f32
%float_result = math.log1p %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @log1p(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.log1p %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}