forked from OSchip/llvm-project
[AArch64][SVE] Combine predicated FMUL/FADD into FMA
Combine FADD and FMUL intrinsics into FMA when the result of the FMUL is an FADD operand with one only use and both use the same predicate. Differential Revision: https://reviews.llvm.org/D111638
This commit is contained in:
parent
8689f5e6e7
commit
fc28a2f8ce
|
@ -243,6 +243,9 @@ public:
|
|||
void operator|=(const FastMathFlags &OtherFlags) {
|
||||
Flags |= OtherFlags.Flags;
|
||||
}
|
||||
bool operator!=(const FastMathFlags &OtherFlags) const {
|
||||
return Flags != OtherFlags.Flags;
|
||||
}
|
||||
};
|
||||
|
||||
/// Utility class for floating point operations which can have
|
||||
|
|
|
@ -695,6 +695,45 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
|
|||
return None;
|
||||
}
|
||||
|
||||
static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
|
||||
IntrinsicInst &II) {
|
||||
// fold (fadd p a (fmul p b c)) -> (fma p a b c)
|
||||
Value *p, *FMul, *a, *b, *c;
|
||||
auto m_SVEFAdd = [](auto p, auto w, auto x) {
|
||||
return m_CombineOr(m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, w, x),
|
||||
m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, x, w));
|
||||
};
|
||||
auto m_SVEFMul = [](auto p, auto y, auto z) {
|
||||
return m_Intrinsic<Intrinsic::aarch64_sve_fmul>(p, y, z);
|
||||
};
|
||||
if (!match(&II, m_SVEFAdd(m_Value(p), m_Value(a),
|
||||
m_CombineAnd(m_Value(FMul),
|
||||
m_SVEFMul(m_Deferred(p), m_Value(b),
|
||||
m_Value(c))))))
|
||||
return None;
|
||||
|
||||
if (!FMul->hasOneUse())
|
||||
return None;
|
||||
|
||||
llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
|
||||
llvm::FastMathFlags FMulFlags = cast<CallInst>(FMul)->getFastMathFlags();
|
||||
// Don't combine when FMul & Fadd flags differ to prevent the loss of any
|
||||
// additional important flags
|
||||
if (FAddFlags != FMulFlags)
|
||||
return None;
|
||||
bool AllowReassoc = FAddFlags.allowReassoc() && FMulFlags.allowReassoc();
|
||||
bool AllowContract = FAddFlags.allowContract() && FMulFlags.allowContract();
|
||||
if (!AllowReassoc || !AllowContract)
|
||||
return None;
|
||||
|
||||
IRBuilder<> Builder(II.getContext());
|
||||
Builder.SetInsertPoint(&II);
|
||||
auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
|
||||
{II.getType()}, {p, a, b, c}, &II);
|
||||
FMLA->setFastMathFlags(FAddFlags);
|
||||
return IC.replaceInstUsesWith(II, FMLA);
|
||||
}
|
||||
|
||||
static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
|
||||
switch (Intrinsic) {
|
||||
case Intrinsic::aarch64_sve_fmul:
|
||||
|
@ -724,6 +763,14 @@ static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
|
|||
return IC.replaceInstUsesWith(II, BinOp);
|
||||
}
|
||||
|
||||
static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
|
||||
IntrinsicInst &II) {
|
||||
auto FMLA = instCombineSVEVectorFMLA(IC, II);
|
||||
if (FMLA)
|
||||
return FMLA;
|
||||
return instCombineSVEVectorBinOp(IC, II);
|
||||
}
|
||||
|
||||
static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
|
||||
IntrinsicInst &II) {
|
||||
auto *OpPredicate = II.getOperand(0);
|
||||
|
@ -901,6 +948,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
|
|||
case Intrinsic::aarch64_sve_fmul:
|
||||
return instCombineSVEVectorMul(IC, II);
|
||||
case Intrinsic::aarch64_sve_fadd:
|
||||
return instCombineSVEVectorFAdd(IC, II);
|
||||
case Intrinsic::aarch64_sve_fsub:
|
||||
return instCombineSVEVectorBinOp(IC, II);
|
||||
case Intrinsic::aarch64_sve_tbl:
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
; RUN: opt -S -instcombine < %s | FileCheck %s
|
||||
|
||||
target triple = "aarch64-unknown-linux-gnu"
|
||||
|
||||
define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @combine_fmla
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: ret <vscale x 8 x half> %6
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_contract_flag_only
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
; CHECK-NEXT: ret <vscale x 8 x half> %7
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_reassoc_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_reassoc_flag_only
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %7 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
; CHECK-NEXT: ret <vscale x 8 x half> %7
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_min_flags(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_min_flags
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: ret <vscale x 8 x half> %6
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fast_flag(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_no_fast_flag
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
; CHECK-NEXT: ret <vscale x 8 x half> %7
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fmul(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_no_fmul
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
; CHECK-NEXT: ret <vscale x 8 x half> %7
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_neq_pred
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
|
||||
; CHECK-NEXT: %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
|
||||
; CHECK-NEXT: %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
|
||||
; ret <vscale x 8 x half> %9
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
|
||||
%7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
|
||||
%8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
|
||||
ret <vscale x 8 x half> %9
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
; CHECK-NEXT: %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
|
||||
; ret <vscale x 8 x half> %8
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
%8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %8
|
||||
}
|
||||
|
||||
define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @neg_combine_fmla_neq_flags
|
||||
; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
; CHECK-NEXT: %6 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
; CHECK-NEXT: %7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
; ret <vscale x 8 x half> %7
|
||||
%5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
|
||||
%6 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
|
||||
%7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
|
||||
ret <vscale x 8 x half> %7
|
||||
}
|
||||
|
||||
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
|
||||
declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
|
||||
declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
|
||||
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
|
||||
attributes #0 = { "target-features"="+sve" }
|
Loading…
Reference in New Issue