forked from OSchip/llvm-project
Don't lower log1p(x) to log(1 + x).
The latter has accuracy issues around 0. The lowering in MathToLLVM is kept for now. Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D131676
This commit is contained in:
parent
556efdba85
commit
375a5cb648
|
@ -15,8 +15,10 @@ template <typename T>
|
|||
class OperationPass;
|
||||
|
||||
/// Populate the given list with patterns that convert from Math to Libm calls.
|
||||
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit);
|
||||
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
|
||||
void populateMathToLibmConversionPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit,
|
||||
llvm::Optional<PatternBenefit> log1pBenefit = llvm::None);
|
||||
|
||||
/// Create a pass to convert Math operations to libm calls.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
|
||||
|
|
|
@ -513,11 +513,28 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
|
|||
|
||||
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
|
||||
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
|
||||
|
||||
Value half = b.create<arith::ConstantOp>(elementType,
|
||||
b.getFloatAttr(elementType, 0.5));
|
||||
Value one = b.create<arith::ConstantOp>(elementType,
|
||||
b.getFloatAttr(elementType, 1));
|
||||
Value two = b.create<arith::ConstantOp>(elementType,
|
||||
b.getFloatAttr(elementType, 2));
|
||||
|
||||
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
|
||||
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
|
||||
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
|
||||
Value sumSq = b.create<arith::MulFOp>(real, real);
|
||||
sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
|
||||
sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
|
||||
Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
|
||||
Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
|
||||
|
||||
Value realPlusOne = b.create<arith::AddFOp>(real, one);
|
||||
Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
|
||||
rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
|
||||
|
||||
Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
|
||||
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
||||
resultImag);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -138,8 +138,9 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
|
|||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
void mlir::populateMathToLibmConversionPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit,
|
||||
llvm::Optional<PatternBenefit> log1pBenefit) {
|
||||
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
|
||||
VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
|
||||
VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
|
||||
|
@ -168,6 +169,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
|
|||
"cos", benefit);
|
||||
patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
|
||||
"sin", benefit);
|
||||
patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
|
||||
patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
|
||||
// RUN: FileCheck %s
|
||||
// RUN: FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @complex_abs
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
|
@ -262,21 +262,21 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
|
|||
%log1p = complex.log1p %arg: complex<f32>
|
||||
return %log1p : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
|
||||
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
|
||||
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
|
||||
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
|
||||
// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
|
||||
// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
|
||||
// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
|
||||
// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
|
||||
// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
|
||||
// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
|
||||
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
|
||||
// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex<f32>
|
||||
// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
|
||||
// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
|
||||
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
|
||||
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
|
||||
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
|
||||
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
|
||||
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
|
||||
// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
|
||||
// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
|
||||
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
|
||||
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
|
||||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
|
|
|
@ -303,3 +303,15 @@ func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
|
|||
%double_result = math.tan %double : vector<2xf64>
|
||||
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @log1p_caller
|
||||
// CHECK-SAME: %[[FLOAT:.*]]: f32
|
||||
// CHECK-SAME: %[[DOUBLE:.*]]: f64
|
||||
func.func @log1p_caller(%float: f32, %double: f64) -> (f32, f64) {
|
||||
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @log1pf(%[[FLOAT]]) : (f32) -> f32
|
||||
%float_result = math.log1p %float : f32
|
||||
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @log1p(%[[DOUBLE]]) : (f64) -> f64
|
||||
%double_result = math.log1p %double : f64
|
||||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
|
||||
return %float_result, %double_result : f32, f64
|
||||
}
|
Loading…
Reference in New Issue