[AArch64] Fix MUL/SUB fusing
Summary:
When MUL is the first operand to SUB, we can't use MLS because the
accumulator should be negated. Emit a NEG of the accumulator and an MLA
instead, similar to what we do for FMUL / FSUB fusing.

Reviewers: dmgreen, SjoerdMeijer, fhahn, Gerolf, mstorsjo, asbirlea

Reviewed By: asbirlea

Subscribers: kristof.beyls, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D71067
parent fffd70291e
commit e503fee904
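Why the *_OP1 patterns need the extra negation: MLS computes acc - a*b, which matches sub(c, mul(a, b)) (the *_OP2 patterns) but not sub(mul(a, b), c). Because vector lanes wrap modulo 2^n, a*b - c is exactly (-c) + a*b, so a NEG of the accumulator followed by MLA is a correct replacement. A standalone C++ sketch (illustrative only, not part of the patch; mls/mla model per-lane semantics on 8-bit lanes):

// Verifies the per-lane identity a*b - c == (-c) + a*b that justifies
// rewriting SUB(MUL(a, b), c) as NEG(c) followed by MLA.
#include <cassert>
#include <cstdint>

// MLS lane semantics: acc - a*b (matches the *_OP2 form, sub(c, mul)).
static uint8_t mls(uint8_t Acc, uint8_t A, uint8_t B) {
  return static_cast<uint8_t>(Acc - A * B);
}
// MLA lane semantics: acc + a*b.
static uint8_t mla(uint8_t Acc, uint8_t A, uint8_t B) {
  return static_cast<uint8_t>(Acc + A * B);
}

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      for (unsigned C = 0; C < 256; C += 17) {
        // The *_OP1 pattern: sub(mul(A, B), C).
        uint8_t Want = static_cast<uint8_t>(A * B - C);
        // NEG the accumulator, then MLA: the new expansion.
        assert(mla(static_cast<uint8_t>(-C), static_cast<uint8_t>(A),
                   static_cast<uint8_t>(B)) == Want);
        // Plain MLS computes the other operand order (the *_OP2 form).
        assert(mls(static_cast<uint8_t>(C), static_cast<uint8_t>(A),
                   static_cast<uint8_t>(B)) ==
               static_cast<uint8_t>(C - A * B));
      }
  return 0;
}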
@@ -4198,6 +4198,40 @@ static MachineInstr *genFusedMultiplyAcc(
                           FMAInstKind::Accumulator);
 }
 
+/// genNeg - Helper to generate an intermediate negation of the second operand
+/// of Root
+static Register genNeg(MachineFunction &MF, MachineRegisterInfo &MRI,
+                       const TargetInstrInfo *TII, MachineInstr &Root,
+                       SmallVectorImpl<MachineInstr *> &InsInstrs,
+                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
+                       unsigned MnegOpc, const TargetRegisterClass *RC) {
+  Register NewVR = MRI.createVirtualRegister(RC);
+  MachineInstrBuilder MIB =
+      BuildMI(MF, Root.getDebugLoc(), TII->get(MnegOpc), NewVR)
+          .add(Root.getOperand(2));
+  InsInstrs.push_back(MIB);
+
+  assert(InstrIdxForVirtReg.empty());
+  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+
+  return NewVR;
+}
+
+/// genFusedMultiplyAccNeg - Helper to generate fused multiply accumulate
+/// instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyAccNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Accumulator, &NewVR);
+}
+
 /// genFusedMultiplyIdx - Helper to generate fused multiply accumulate
 /// instructions.
 ///
@@ -4210,6 +4244,22 @@ static MachineInstr *genFusedMultiplyIdx(
                           FMAInstKind::Indexed);
 }
 
+/// genFusedMultiplyIdxNeg - Helper to generate fused multiply accumulate
+/// instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyIdxNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Indexed, &NewVR);
+}
+
 /// genMaddR - Generate madd instruction and combine mul and add using
 /// an extra virtual register
 /// Example - an ADD intermediate needs to be stored in a register:
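For context on the combiner plumbing (my summary, not text from the patch): InsInstrs collects the replacement instructions in program order, and the MachineCombiner expects InstrIdxForVirtReg to map each newly created virtual register to the index of its defining instruction in that vector, so it can recompute instruction depths for the new sequence. A simplified, hypothetical mirror of the bookkeeping genNeg performs, using stand-in types rather than the real LLVM ones:

// FakeInstr and emitNeg are stand-ins for illustration, not LLVM types.
#include <cassert>
#include <unordered_map>
#include <vector>

struct FakeInstr {
  unsigned Opcode; // stand-in for a real MachineInstr
  unsigned DefReg;
};

// The NEG defining NewVR is pushed into InsInstrs first, so its index in
// that vector is 0, and InstrIdxForVirtReg records exactly that pair.
unsigned emitNeg(std::vector<FakeInstr> &InsInstrs,
                 std::unordered_map<unsigned, unsigned> &InstrIdxForVirtReg,
                 unsigned NegOpc, unsigned NewVR) {
  assert(InstrIdxForVirtReg.empty() && "NEG must be the first new instr");
  unsigned Idx = static_cast<unsigned>(InsInstrs.size()); // 0 here
  InsInstrs.push_back({NegOpc, NewVR});
  InstrIdxForVirtReg.insert({NewVR, Idx});
  return NewVR;
}

int main() {
  std::vector<FakeInstr> InsInstrs;
  std::unordered_map<unsigned, unsigned> Map;
  unsigned NewVR = emitNeg(InsInstrs, Map, /*NegOpc=*/1, /*NewVR=*/42);
  assert(Map.at(NewVR) == 0 && InsInstrs.size() == 1);
  return 0;
}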
@@ -4512,9 +4562,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     break;
 
   case MachineCombinerPattern::MULSUBv8i8_OP1:
-    Opc = AArch64::MLSv8i8;
+    Opc = AArch64::MLAv8i8;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i8,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i8_OP2:
     Opc = AArch64::MLSv8i8;
@@ -4522,9 +4574,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv16i8_OP1:
-    Opc = AArch64::MLSv16i8;
+    Opc = AArch64::MLAv16i8;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv16i8,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv16i8_OP2:
     Opc = AArch64::MLSv16i8;
@@ -4532,9 +4586,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_OP1:
-    Opc = AArch64::MLSv4i16;
+    Opc = AArch64::MLAv4i16;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_OP2:
     Opc = AArch64::MLSv4i16;
@@ -4542,9 +4598,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_OP1:
-    Opc = AArch64::MLSv8i16;
+    Opc = AArch64::MLAv8i16;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_OP2:
     Opc = AArch64::MLSv8i16;
@@ -4552,9 +4610,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_OP1:
-    Opc = AArch64::MLSv2i32;
+    Opc = AArch64::MLAv2i32;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_OP2:
     Opc = AArch64::MLSv2i32;
@@ -4562,9 +4622,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_OP1:
-    Opc = AArch64::MLSv4i32;
+    Opc = AArch64::MLAv4i32;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_OP2:
     Opc = AArch64::MLSv4i32;
@@ -4614,9 +4676,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     break;
 
   case MachineCombinerPattern::MULSUBv4i16_indexed_OP1:
-    Opc = AArch64::MLSv4i16_indexed;
+    Opc = AArch64::MLAv4i16_indexed;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_indexed_OP2:
     Opc = AArch64::MLSv4i16_indexed;
@@ -4624,9 +4688,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_indexed_OP1:
-    Opc = AArch64::MLSv8i16_indexed;
+    Opc = AArch64::MLAv8i16_indexed;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_indexed_OP2:
     Opc = AArch64::MLSv8i16_indexed;
@@ -4634,9 +4700,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_indexed_OP1:
-    Opc = AArch64::MLSv2i32_indexed;
+    Opc = AArch64::MLAv2i32_indexed;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
     Opc = AArch64::MLSv2i32_indexed;
@@ -4644,9 +4712,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
-    Opc = AArch64::MLSv4i32_indexed;
+    Opc = AArch64::MLAv4i32_indexed;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
     Opc = AArch64::MLSv4i32_indexed;
@@ -135,3 +135,75 @@ define <4 x i32> @mls4xi32(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C) {
 }
 
+
+define <8 x i8> @mls2v8xi8(<8 x i8> %A, <8 x i8> %B, <8 x i8> %C) {
+; CHECK-LABEL: mls2v8xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.8b, v2.8b
+; CHECK-NEXT:    mla v2.8b, v0.8b, v1.8b
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <8 x i8> %A, %B;
+  %tmp2 = sub <8 x i8> %tmp1, %C;
+  ret <8 x i8> %tmp2
+}
+
+define <16 x i8> @mls2v16xi8(<16 x i8> %A, <16 x i8> %B, <16 x i8> %C) {
+; CHECK-LABEL: mls2v16xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.16b, v2.16b
+; CHECK-NEXT:    mla v2.16b, v0.16b, v1.16b
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <16 x i8> %A, %B;
+  %tmp2 = sub <16 x i8> %tmp1, %C;
+  ret <16 x i8> %tmp2
+}
+
+define <4 x i16> @mls2v4xi16(<4 x i16> %A, <4 x i16> %B, <4 x i16> %C) {
+; CHECK-LABEL: mls2v4xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.4h, v2.4h
+; CHECK-NEXT:    mla v2.4h, v0.4h, v1.4h
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <4 x i16> %A, %B;
+  %tmp2 = sub <4 x i16> %tmp1, %C;
+  ret <4 x i16> %tmp2
+}
+
+define <8 x i16> @mls2v8xi16(<8 x i16> %A, <8 x i16> %B, <8 x i16> %C) {
+; CHECK-LABEL: mls2v8xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.8h, v2.8h
+; CHECK-NEXT:    mla v2.8h, v0.8h, v1.8h
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <8 x i16> %A, %B;
+  %tmp2 = sub <8 x i16> %tmp1, %C;
+  ret <8 x i16> %tmp2
+}
+
+define <2 x i32> @mls2v2xi32(<2 x i32> %A, <2 x i32> %B, <2 x i32> %C) {
+; CHECK-LABEL: mls2v2xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.2s, v2.2s
+; CHECK-NEXT:    mla v2.2s, v0.2s, v1.2s
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <2 x i32> %A, %B;
+  %tmp2 = sub <2 x i32> %tmp1, %C;
+  ret <2 x i32> %tmp2
+}
+
+define <4 x i32> @mls2v4xi32(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C) {
+; CHECK-LABEL: mls2v4xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.4s, v2.4s
+; CHECK-NEXT:    mla v2.4s, v0.4s, v1.4s
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <4 x i32> %A, %B;
+  %tmp2 = sub <4 x i32> %tmp1, %C;
+  ret <4 x i32> %tmp2
+}