[AArch64] Fix MUL/SUB fusing

Summary:
When MUL is the first operand to SUB, we can't use MLS because the accumulator
should be negated.  Emit a NEG of the accumulator and an MLA instead, similar to
what we do for FMUL / FSUB fusing.

Reviewers: dmgreen, SjoerdMeijer, fhahn, Gerolf, mstorsjo, asbirlea

Reviewed By: asbirlea

Subscribers: kristof.beyls, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D71067
This commit is contained in:
Sanne Wouda 2019-12-05 18:09:08 +00:00
parent fffd70291e
commit e503fee904
2 changed files with 162 additions and 20 deletions

View File

@ -4198,6 +4198,40 @@ static MachineInstr *genFusedMultiplyAcc(
FMAInstKind::Accumulator);
}
/// genNeg - Helper to generate an intermediate negation of the second operand
/// of Root
static Register genNeg(MachineFunction &MF, MachineRegisterInfo &MRI,
const TargetInstrInfo *TII, MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
unsigned MnegOpc, const TargetRegisterClass *RC) {
Register NewVR = MRI.createVirtualRegister(RC);
MachineInstrBuilder MIB =
BuildMI(MF, Root.getDebugLoc(), TII->get(MnegOpc), NewVR)
.add(Root.getOperand(2));
InsInstrs.push_back(MIB);
assert(InstrIdxForVirtReg.empty());
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
return NewVR;
}
/// genFusedMultiplyAccNeg - Helper to generate fused multiply accumulate
/// instructions with an additional negation of the accumulator
static MachineInstr *genFusedMultiplyAccNeg(
MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
assert(IdxMulOpd == 1);
Register NewVR =
genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
FMAInstKind::Accumulator, &NewVR);
}
/// genFusedMultiplyIdx - Helper to generate fused multiply accumulate
/// instructions.
///
@ -4210,6 +4244,22 @@ static MachineInstr *genFusedMultiplyIdx(
FMAInstKind::Indexed);
}
/// genFusedMultiplyAccNeg - Helper to generate fused multiply accumulate
/// instructions with an additional negation of the accumulator
static MachineInstr *genFusedMultiplyIdxNeg(
MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
assert(IdxMulOpd == 1);
Register NewVR =
genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
FMAInstKind::Indexed, &NewVR);
}
/// genMaddR - Generate madd instruction and combine mul and add using
/// an extra virtual register
/// Example - an ADD intermediate needs to be stored in a register:
@ -4512,9 +4562,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
break;
case MachineCombinerPattern::MULSUBv8i8_OP1:
Opc = AArch64::MLSv8i8;
Opc = AArch64::MLAv8i8;
RC = &AArch64::FPR64RegClass;
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i8,
RC);
break;
case MachineCombinerPattern::MULSUBv8i8_OP2:
Opc = AArch64::MLSv8i8;
@ -4522,9 +4574,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv16i8_OP1:
Opc = AArch64::MLSv16i8;
Opc = AArch64::MLAv16i8;
RC = &AArch64::FPR128RegClass;
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv16i8,
RC);
break;
case MachineCombinerPattern::MULSUBv16i8_OP2:
Opc = AArch64::MLSv16i8;
@ -4532,9 +4586,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv4i16_OP1:
Opc = AArch64::MLSv4i16;
Opc = AArch64::MLAv4i16;
RC = &AArch64::FPR64RegClass;
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
RC);
break;
case MachineCombinerPattern::MULSUBv4i16_OP2:
Opc = AArch64::MLSv4i16;
@ -4542,9 +4598,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv8i16_OP1:
Opc = AArch64::MLSv8i16;
Opc = AArch64::MLAv8i16;
RC = &AArch64::FPR128RegClass;
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
RC);
break;
case MachineCombinerPattern::MULSUBv8i16_OP2:
Opc = AArch64::MLSv8i16;
@ -4552,9 +4610,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv2i32_OP1:
Opc = AArch64::MLSv2i32;
Opc = AArch64::MLAv2i32;
RC = &AArch64::FPR64RegClass;
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
RC);
break;
case MachineCombinerPattern::MULSUBv2i32_OP2:
Opc = AArch64::MLSv2i32;
@ -4562,9 +4622,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv4i32_OP1:
Opc = AArch64::MLSv4i32;
Opc = AArch64::MLAv4i32;
RC = &AArch64::FPR128RegClass;
MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
RC);
break;
case MachineCombinerPattern::MULSUBv4i32_OP2:
Opc = AArch64::MLSv4i32;
@ -4614,9 +4676,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
break;
case MachineCombinerPattern::MULSUBv4i16_indexed_OP1:
Opc = AArch64::MLSv4i16_indexed;
Opc = AArch64::MLAv4i16_indexed;
RC = &AArch64::FPR64RegClass;
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
RC);
break;
case MachineCombinerPattern::MULSUBv4i16_indexed_OP2:
Opc = AArch64::MLSv4i16_indexed;
@ -4624,9 +4688,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv8i16_indexed_OP1:
Opc = AArch64::MLSv8i16_indexed;
Opc = AArch64::MLAv8i16_indexed;
RC = &AArch64::FPR128RegClass;
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
RC);
break;
case MachineCombinerPattern::MULSUBv8i16_indexed_OP2:
Opc = AArch64::MLSv8i16_indexed;
@ -4634,9 +4700,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv2i32_indexed_OP1:
Opc = AArch64::MLSv2i32_indexed;
Opc = AArch64::MLAv2i32_indexed;
RC = &AArch64::FPR64RegClass;
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
RC);
break;
case MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
Opc = AArch64::MLSv2i32_indexed;
@ -4644,9 +4712,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
Opc = AArch64::MLSv4i32_indexed;
Opc = AArch64::MLAv4i32_indexed;
RC = &AArch64::FPR128RegClass;
MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
RC);
break;
case MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
Opc = AArch64::MLSv4i32_indexed;

View File

@ -135,3 +135,75 @@ define <4 x i32> @mls4xi32(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C) {
}
define <8 x i8> @mls2v8xi8(<8 x i8> %A, <8 x i8> %B, <8 x i8> %C) {
; CHECK-LABEL: mls2v8xi8:
; CHECK: // %bb.0:
; CHECK-NEXT: neg v2.8b, v2.8b
; CHECK-NEXT: mla v2.8b, v0.8b, v1.8b
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%tmp1 = mul <8 x i8> %A, %B;
%tmp2 = sub <8 x i8> %tmp1, %C;
ret <8 x i8> %tmp2
}
define <16 x i8> @mls2v16xi8(<16 x i8> %A, <16 x i8> %B, <16 x i8> %C) {
; CHECK-LABEL: mls2v16xi8:
; CHECK: // %bb.0:
; CHECK-NEXT: neg v2.16b, v2.16b
; CHECK-NEXT: mla v2.16b, v0.16b, v1.16b
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%tmp1 = mul <16 x i8> %A, %B;
%tmp2 = sub <16 x i8> %tmp1, %C;
ret <16 x i8> %tmp2
}
define <4 x i16> @mls2v4xi16(<4 x i16> %A, <4 x i16> %B, <4 x i16> %C) {
; CHECK-LABEL: mls2v4xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: neg v2.4h, v2.4h
; CHECK-NEXT: mla v2.4h, v0.4h, v1.4h
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%tmp1 = mul <4 x i16> %A, %B;
%tmp2 = sub <4 x i16> %tmp1, %C;
ret <4 x i16> %tmp2
}
define <8 x i16> @mls2v8xi16(<8 x i16> %A, <8 x i16> %B, <8 x i16> %C) {
; CHECK-LABEL: mls2v8xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: neg v2.8h, v2.8h
; CHECK-NEXT: mla v2.8h, v0.8h, v1.8h
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%tmp1 = mul <8 x i16> %A, %B;
%tmp2 = sub <8 x i16> %tmp1, %C;
ret <8 x i16> %tmp2
}
define <2 x i32> @mls2v2xi32(<2 x i32> %A, <2 x i32> %B, <2 x i32> %C) {
; CHECK-LABEL: mls2v2xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: neg v2.2s, v2.2s
; CHECK-NEXT: mla v2.2s, v0.2s, v1.2s
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%tmp1 = mul <2 x i32> %A, %B;
%tmp2 = sub <2 x i32> %tmp1, %C;
ret <2 x i32> %tmp2
}
define <4 x i32> @mls2v4xi32(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C) {
; CHECK-LABEL: mls2v4xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: neg v2.4s, v2.4s
; CHECK-NEXT: mla v2.4s, v0.4s, v1.4s
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%tmp1 = mul <4 x i32> %A, %B;
%tmp2 = sub <4 x i32> %tmp1, %C;
ret <4 x i32> %tmp2
}