[AArch64][SVE] Combine predicated FMUL/FADD into FMA

Combine FADD and FMUL intrinsics into FMA when the result of the FMUL is an FADD operand
with one only use and both use the same predicate.

Differential Revision: https://reviews.llvm.org/D111638
This commit is contained in:
Matt 2021-10-12 10:18:59 +00:00
parent 8689f5e6e7
commit fc28a2f8ce
3 changed files with 172 additions and 0 deletions

View File

@ -243,6 +243,9 @@ public:
void operator|=(const FastMathFlags &OtherFlags) {
Flags |= OtherFlags.Flags;
}
bool operator!=(const FastMathFlags &OtherFlags) const {
return Flags != OtherFlags.Flags;
}
};
/// Utility class for floating point operations which can have

View File

@ -695,6 +695,45 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
return None;
}
static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
IntrinsicInst &II) {
// fold (fadd p a (fmul p b c)) -> (fma p a b c)
Value *p, *FMul, *a, *b, *c;
auto m_SVEFAdd = [](auto p, auto w, auto x) {
return m_CombineOr(m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, w, x),
m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, x, w));
};
auto m_SVEFMul = [](auto p, auto y, auto z) {
return m_Intrinsic<Intrinsic::aarch64_sve_fmul>(p, y, z);
};
if (!match(&II, m_SVEFAdd(m_Value(p), m_Value(a),
m_CombineAnd(m_Value(FMul),
m_SVEFMul(m_Deferred(p), m_Value(b),
m_Value(c))))))
return None;
if (!FMul->hasOneUse())
return None;
llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
llvm::FastMathFlags FMulFlags = cast<CallInst>(FMul)->getFastMathFlags();
// Don't combine when FMul & Fadd flags differ to prevent the loss of any
// additional important flags
if (FAddFlags != FMulFlags)
return None;
bool AllowReassoc = FAddFlags.allowReassoc() && FMulFlags.allowReassoc();
bool AllowContract = FAddFlags.allowContract() && FMulFlags.allowContract();
if (!AllowReassoc || !AllowContract)
return None;
IRBuilder<> Builder(II.getContext());
Builder.SetInsertPoint(&II);
auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
{II.getType()}, {p, a, b, c}, &II);
FMLA->setFastMathFlags(FAddFlags);
return IC.replaceInstUsesWith(II, FMLA);
}
static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
switch (Intrinsic) {
case Intrinsic::aarch64_sve_fmul:
@ -724,6 +763,14 @@ static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
return IC.replaceInstUsesWith(II, BinOp);
}
static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
IntrinsicInst &II) {
auto FMLA = instCombineSVEVectorFMLA(IC, II);
if (FMLA)
return FMLA;
return instCombineSVEVectorBinOp(IC, II);
}
static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
IntrinsicInst &II) {
auto *OpPredicate = II.getOperand(0);
@ -901,6 +948,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
case Intrinsic::aarch64_sve_fmul:
return instCombineSVEVectorMul(IC, II);
case Intrinsic::aarch64_sve_fadd:
return instCombineSVEVectorFAdd(IC, II);
case Intrinsic::aarch64_sve_fsub:
return instCombineSVEVectorBinOp(IC, II);
case Intrinsic::aarch64_sve_tbl:

View File

@ -0,0 +1,121 @@
; RUN: opt -S -instcombine < %s | FileCheck %s
target triple = "aarch64-unknown-linux-gnu"
define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @combine_fmla
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: ret <vscale x 8 x half> %6
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_contract_flag_only
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
; CHECK-NEXT: ret <vscale x 8 x half> %7
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_reassoc_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_reassoc_flag_only
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %7 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
; CHECK-NEXT: ret <vscale x 8 x half> %7
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_min_flags(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_min_flags
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: ret <vscale x 8 x half> %6
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fast_flag(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_no_fast_flag
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
; CHECK-NEXT: ret <vscale x 8 x half> %7
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fmul(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_no_fmul
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
; CHECK-NEXT: ret <vscale x 8 x half> %7
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_neq_pred
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
; CHECK-NEXT: %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
; CHECK-NEXT: %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
; ret <vscale x 8 x half> %9
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
%7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
%8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
ret <vscale x 8 x half> %9
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
; CHECK-NEXT: %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
; ret <vscale x 8 x half> %8
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
%8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %8
}
define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
; CHECK-LABEL: @neg_combine_fmla_neq_flags
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
; CHECK-NEXT: %6 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
; CHECK-NEXT: %7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
; ret <vscale x 8 x half> %7
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
%6 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
%7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
ret <vscale x 8 x half> %7
}
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
attributes #0 = { "target-features"="+sve" }