[AArch64][SVE] Combine FADD and FMUL aarch64 intrinsics to FMLA

This is a refinement to the work in
https://reviews.llvm.org/D111638

Fold (fadd p a (fmul p b c)) into (fma p a b c)

Differential Revision: https://reviews.llvm.org/D113095
Matt authored on 2021-11-03 11:31:41 +00:00; committed by Matt Devereau
parent db289340c8
commit 4a59694ba1
3 changed files with 149 additions and 0 deletions
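
To make the fold concrete, here is a minimal before/after sketch in LLVM IR, modelled on the nxv8f16 tests added below; the value names and the fast flag are illustrative rather than part of the patch.

Before the combine:
  %mul = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
  %res = call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %a, <vscale x 8 x half> %mul)

After instcombine, assuming %mul has no other users and both calls carry identical, contract-enabling fast-math flags:
  %res = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c)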


@@ -247,6 +247,9 @@ public:
  void operator|=(const FastMathFlags &OtherFlags) {
    Flags |= OtherFlags.Flags;
  }
  bool operator!=(const FastMathFlags &OtherFlags) const {
    return Flags != OtherFlags.Flags;
  }
};

/// Utility class for floating point operations which can have


@@ -695,6 +695,36 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
  return None;
}

static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
  Value *P = II.getOperand(0);
  Value *A = II.getOperand(1);
  auto FMul = II.getOperand(2);
  Value *B, *C;
  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
                       m_Specific(P), m_Value(B), m_Value(C))))
    return None;

  if (!FMul->hasOneUse())
    return None;

  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
  // Stop the combine when the flags on the inputs differ in case dropping
  // flags would lead to us missing out on more beneficial optimizations.
  if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
    return None;
  if (!FAddFlags.allowContract())
    return None;

  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);
  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
                                      {II.getType()}, {P, A, B, C}, &II);
  FMLA->setFastMathFlags(FAddFlags);
  return IC.replaceInstUsesWith(II, FMLA);
}

static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
  switch (Intrinsic) {
  case Intrinsic::aarch64_sve_fmul:
@@ -724,6 +754,13 @@ static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
  return IC.replaceInstUsesWith(II, BinOp);
}

static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
    return FMLA;
  return instCombineSVEVectorBinOp(IC, II);
}

static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  auto *OpPredicate = II.getOperand(0);
@@ -969,6 +1006,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
  case Intrinsic::aarch64_sve_fmul:
    return instCombineSVEVectorMul(IC, II);
  case Intrinsic::aarch64_sve_fadd:
    return instCombineSVEVectorFAdd(IC, II);
  case Intrinsic::aarch64_sve_fsub:
    return instCombineSVEVectorBinOp(IC, II);
  case Intrinsic::aarch64_sve_tbl:


@@ -0,0 +1,108 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -instcombine < %s | FileCheck %s
target triple = "aarch64-unknown-linux-gnu"
define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @combine_fmla(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP2]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %3
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_mul_first_operand(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_mul_first_operand(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP2]], <vscale x 8 x half> [[A:%.*]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP3]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %a)
ret <vscale x 8 x half> %3
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_contract_flag_only(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP2]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %3
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_no_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_no_flags(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP3]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %3
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_neq_pred(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
; CHECK-NEXT: [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP4]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP5]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
%3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
%4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %a, <vscale x 8 x half> %4)
ret <vscale x 8 x half> %5
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
; CHECK-NEXT: [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP3]], <vscale x 8 x half> [[TMP2]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP4]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
%4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %3, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %4
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_neq_flags(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP3]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
%3 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %3
}
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
attributes #0 = { "target-features"="+sve" }