forked from OSchip/llvm-project
[mlir] Add lowering from math.expm1 to LLVM.
Differential Revision: https://reviews.llvm.org/D96776
This commit is contained in:
parent
8f2948731e
commit
93537fabce
|
@ -2352,6 +2352,60 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
|
|||
}
|
||||
};
|
||||
|
||||
// A `expm1` is converted into `exp - 1`.
|
||||
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
|
||||
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
math::ExpM1Op::Adaptor transformed(operands);
|
||||
auto operandType = transformed.operand().getType();
|
||||
|
||||
if (!operandType || !LLVM::isCompatibleType(operandType))
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto resultType = op.getResult().getType();
|
||||
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
||||
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
||||
|
||||
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
||||
LLVM::ConstantOp one;
|
||||
if (LLVM::isCompatibleVectorType(operandType)) {
|
||||
one = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, operandType,
|
||||
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
|
||||
} else {
|
||||
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
||||
}
|
||||
auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
|
||||
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto vectorType = resultType.dyn_cast<VectorType>();
|
||||
if (!vectorType)
|
||||
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
||||
|
||||
return handleMultidimensionalVectors(
|
||||
op.getOperation(), operands, *getTypeConverter(),
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
auto splatAttr = SplatElementsAttr::get(
|
||||
mlir::VectorType::get(
|
||||
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
||||
floatType),
|
||||
floatOne);
|
||||
auto one =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
||||
auto exp =
|
||||
rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
|
||||
return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
|
||||
},
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
// A `log1p` is converted into `log(1 + ...)`.
|
||||
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
|
||||
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
|
||||
|
@ -3924,6 +3978,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
|||
DivFOpLowering,
|
||||
ExpOpLowering,
|
||||
Exp2OpLowering,
|
||||
ExpM1OpLowering,
|
||||
FloorFOpLowering,
|
||||
FmaFOpLowering,
|
||||
GenericAtomicRMWOpLowering,
|
||||
|
|
|
@ -37,6 +37,18 @@ func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @expm1(
|
||||
// CHECK-SAME: f32
|
||||
func @expm1(%arg0 : f32) {
|
||||
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
|
||||
// CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32
|
||||
// CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
|
||||
%0 = math.expm1 %arg0 : f32
|
||||
std.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rsqrt(
|
||||
// CHECK-SAME: f32
|
||||
func @rsqrt(%arg0 : f32) {
|
||||
|
|
Loading…
Reference in New Issue