diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 0b2c04bea763..df883d35fa18 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -2426,34 +2426,20 @@ bool AArch64InstrInfo::hasPattern( static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII, MachineInstr &Root, SmallVectorImpl &InsInstrs, - unsigned IdxMulOpd, unsigned MaddOpc, - const TargetRegisterClass *RC) { + unsigned IdxMulOpd, unsigned MaddOpc) { assert(IdxMulOpd == 1 || IdxMulOpd == 2); unsigned IdxOtherOpd = IdxMulOpd == 1 ? 2 : 1; MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg()); - unsigned ResultReg = Root.getOperand(0).getReg(); - unsigned SrcReg0 = MUL->getOperand(1).getReg(); - bool Src0IsKill = MUL->getOperand(1).isKill(); - unsigned SrcReg1 = MUL->getOperand(2).getReg(); - bool Src1IsKill = MUL->getOperand(2).isKill(); - unsigned SrcReg2 = Root.getOperand(IdxOtherOpd).getReg(); - bool Src2IsKill = Root.getOperand(IdxOtherOpd).isKill(); - - if (TargetRegisterInfo::isVirtualRegister(ResultReg)) - MRI.constrainRegClass(ResultReg, RC); - if (TargetRegisterInfo::isVirtualRegister(SrcReg0)) - MRI.constrainRegClass(SrcReg0, RC); - if (TargetRegisterInfo::isVirtualRegister(SrcReg1)) - MRI.constrainRegClass(SrcReg1, RC); - if (TargetRegisterInfo::isVirtualRegister(SrcReg2)) - MRI.constrainRegClass(SrcReg2, RC); - - MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), - ResultReg) - .addReg(SrcReg0, getKillRegState(Src0IsKill)) - .addReg(SrcReg1, getKillRegState(Src1IsKill)) - .addReg(SrcReg2, getKillRegState(Src2IsKill)); + MachineOperand R = Root.getOperand(0); + MachineOperand A = MUL->getOperand(1); + MachineOperand B = MUL->getOperand(2); + MachineOperand C = Root.getOperand(IdxOtherOpd); + MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc)) + .addOperand(R) + .addOperand(A) + .addOperand(B) + .addOperand(C); // Insert the MADD InsInstrs.push_back(MIB); return MUL; @@ -2478,35 +2464,22 @@ static MachineInstr *genMaddR(MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII, MachineInstr &Root, SmallVectorImpl &InsInstrs, unsigned IdxMulOpd, unsigned MaddOpc, - unsigned VR, const TargetRegisterClass *RC) { + unsigned VR) { assert(IdxMulOpd == 1 || IdxMulOpd == 2); MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg()); - unsigned ResultReg = Root.getOperand(0).getReg(); - unsigned SrcReg0 = MUL->getOperand(1).getReg(); - bool Src0IsKill = MUL->getOperand(1).isKill(); - unsigned SrcReg1 = MUL->getOperand(2).getReg(); - bool Src1IsKill = MUL->getOperand(2).isKill(); - - if (TargetRegisterInfo::isVirtualRegister(ResultReg)) - MRI.constrainRegClass(ResultReg, RC); - if (TargetRegisterInfo::isVirtualRegister(SrcReg0)) - MRI.constrainRegClass(SrcReg0, RC); - if (TargetRegisterInfo::isVirtualRegister(SrcReg1)) - MRI.constrainRegClass(SrcReg1, RC); - if (TargetRegisterInfo::isVirtualRegister(VR)) - MRI.constrainRegClass(VR, RC); - - MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), - ResultReg) - .addReg(SrcReg0, getKillRegState(Src0IsKill)) - .addReg(SrcReg1, getKillRegState(Src1IsKill)) + MachineOperand R = Root.getOperand(0); + MachineOperand A = MUL->getOperand(1); + MachineOperand B = MUL->getOperand(2); + MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc)) + .addOperand(R) + .addOperand(A) + .addOperand(B) .addReg(VR); // Insert the MADD InsInstrs.push_back(MIB); return MUL; } - /// genAlternativeCodeSequence - when hasPattern() finds a pattern /// this function generates the instructions that could replace the /// original code sequence @@ -2521,7 +2494,6 @@ void AArch64InstrInfo::genAlternativeCodeSequence( const TargetInstrInfo *TII = MF.getTarget().getSubtargetImpl()->getInstrInfo(); MachineInstr *MUL; - const TargetRegisterClass *RC = nullptr; unsigned Opc; switch (Pattern) { default: @@ -2535,11 +2507,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // --- Create(MADD); Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP1 ? AArch64::MADDWrrr : AArch64::MADDXrrr; - if (Pattern == MachineCombinerPattern::MC_MULADDW_OP1) - RC = &AArch64::GPR32RegClass; - else - RC = &AArch64::GPR64RegClass; - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc); break; case MachineCombinerPattern::MC_MULADDW_OP2: case MachineCombinerPattern::MC_MULADDX_OP2: @@ -2549,56 +2517,52 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // --- Create(MADD); Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP2 ? AArch64::MADDWrrr : AArch64::MADDXrrr; - if (Pattern == MachineCombinerPattern::MC_MULADDW_OP2) - RC = &AArch64::GPR32RegClass; - else - RC = &AArch64::GPR64RegClass; - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc); break; case MachineCombinerPattern::MC_MULADDWI_OP1: - case MachineCombinerPattern::MC_MULADDXI_OP1: { + case MachineCombinerPattern::MC_MULADDXI_OP1: // MUL I=A,B,0 // ADD R,I,Imm // ==> ORR V, ZR, Imm // ==> MADD R,A,B,V // --- Create(MADD); - const TargetRegisterClass *OrrRC = nullptr; - unsigned BitSize, OrrOpc, ZeroReg; - if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) { - OrrOpc = AArch64::ORRWri; - OrrRC = &AArch64::GPR32spRegClass; - BitSize = 32; - ZeroReg = AArch64::WZR; - Opc = AArch64::MADDWrrr; - RC = &AArch64::GPR32RegClass; - } else { - OrrOpc = AArch64::ORRXri; - OrrRC = &AArch64::GPR64spRegClass; - BitSize = 64; - ZeroReg = AArch64::XZR; - Opc = AArch64::MADDXrrr; - RC = &AArch64::GPR64RegClass; - } - unsigned NewVR = MRI.createVirtualRegister(OrrRC); - uint64_t Imm = Root.getOperand(2).getImm(); + { + const TargetRegisterClass *RC = + MRI.getRegClass(Root.getOperand(1).getReg()); + unsigned NewVR = MRI.createVirtualRegister(RC); + unsigned BitSize, OrrOpc, ZeroReg; + if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) { + BitSize = 32; + OrrOpc = AArch64::ORRWri; + ZeroReg = AArch64::WZR; + Opc = AArch64::MADDWrrr; + } else { + OrrOpc = AArch64::ORRXri; + BitSize = 64; + ZeroReg = AArch64::XZR; + Opc = AArch64::MADDXrrr; + } + uint64_t Imm = Root.getOperand(2).getImm(); - if (Root.getOperand(3).isImm()) { - unsigned Val = Root.getOperand(3).getImm(); - Imm = Imm << Val; - } - uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize); - uint64_t Encoding; - if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { - MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) - .addReg(ZeroReg) - .addImm(Encoding); - InsInstrs.push_back(MIB1); - InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); + if (Root.getOperand(3).isImm()) { + unsigned val = Root.getOperand(3).getImm(); + Imm = Imm << val; + } + uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize); + uint64_t Encoding; + + if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { + MachineInstrBuilder MIB1 = + BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc)) + .addOperand(MachineOperand::CreateReg(NewVR, true)) + .addReg(ZeroReg) + .addImm(Encoding); + InsInstrs.push_back(MIB1); + InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); + } } break; - } case MachineCombinerPattern::MC_MULSUBW_OP1: case MachineCombinerPattern::MC_MULSUBX_OP1: { // MUL I=A,B,0 @@ -2606,32 +2570,29 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // ==> SUB V, 0, C // ==> MADD R,A,B,V // = -C + A*B // --- Create(MADD); - const TargetRegisterClass *SubRC = nullptr; + const TargetRegisterClass *RC = + MRI.getRegClass(Root.getOperand(1).getReg()); + unsigned NewVR = MRI.createVirtualRegister(RC); unsigned SubOpc, ZeroReg; if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP1) { SubOpc = AArch64::SUBWrr; - SubRC = &AArch64::GPR32spRegClass; ZeroReg = AArch64::WZR; Opc = AArch64::MADDWrrr; - RC = &AArch64::GPR32RegClass; } else { SubOpc = AArch64::SUBXrr; - SubRC = &AArch64::GPR64spRegClass; ZeroReg = AArch64::XZR; Opc = AArch64::MADDXrrr; - RC = &AArch64::GPR64RegClass; } - unsigned NewVR = MRI.createVirtualRegister(SubRC); // SUB NewVR, 0, C MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc), NewVR) + BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc)) + .addOperand(MachineOperand::CreateReg(NewVR, true)) .addReg(ZeroReg) .addOperand(Root.getOperand(2)); InsInstrs.push_back(MIB1); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); - break; - } + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); + } break; case MachineCombinerPattern::MC_MULSUBW_OP2: case MachineCombinerPattern::MC_MULSUBX_OP2: // MUL I=A,B,0 @@ -2640,11 +2601,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // --- Create(MSUB); Opc = Pattern == MachineCombinerPattern::MC_MULSUBW_OP2 ? AArch64::MSUBWrrr : AArch64::MSUBXrrr; - if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP2) - RC = &AArch64::GPR32RegClass; - else - RC = &AArch64::GPR64RegClass; - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc); break; case MachineCombinerPattern::MC_MULSUBWI_OP1: case MachineCombinerPattern::MC_MULSUBXI_OP1: { @@ -2653,43 +2610,40 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // ==> ORR V, ZR, -Imm // ==> MADD R,A,B,V // = -Imm + A*B // --- Create(MADD); - const TargetRegisterClass *OrrRC = nullptr; + const TargetRegisterClass *RC = + MRI.getRegClass(Root.getOperand(1).getReg()); + unsigned NewVR = MRI.createVirtualRegister(RC); unsigned BitSize, OrrOpc, ZeroReg; if (Pattern == MachineCombinerPattern::MC_MULSUBWI_OP1) { - OrrOpc = AArch64::ORRWri; - RC = &AArch64::GPR32spRegClass; BitSize = 32; + OrrOpc = AArch64::ORRWri; ZeroReg = AArch64::WZR; Opc = AArch64::MADDWrrr; - RC = &AArch64::GPR32RegClass; } else { OrrOpc = AArch64::ORRXri; - RC = &AArch64::GPR64RegClass; BitSize = 64; ZeroReg = AArch64::XZR; Opc = AArch64::MADDXrrr; - RC = &AArch64::GPR64RegClass; } - unsigned NewVR = MRI.createVirtualRegister(OrrRC); int Imm = Root.getOperand(2).getImm(); if (Root.getOperand(3).isImm()) { - unsigned Val = Root.getOperand(3).getImm(); - Imm = Imm << Val; + unsigned val = Root.getOperand(3).getImm(); + Imm = Imm << val; } uint64_t UImm = -Imm << (64 - BitSize) >> (64 - BitSize); uint64_t Encoding; if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) + BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc)) + .addOperand(MachineOperand::CreateReg(NewVR, true)) .addReg(ZeroReg) .addImm(Encoding); InsInstrs.push_back(MIB1); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); } - break; + } break; } - } // end switch (Pattern) // Record MUL and ADD/SUB for deletion DelInstrs.push_back(MUL); DelInstrs.push_back(&Root); diff --git a/llvm/test/CodeGen/AArch64/madd-combiner.ll b/llvm/test/CodeGen/AArch64/madd-combiner.ll deleted file mode 100644 index e7f171b90487..000000000000 --- a/llvm/test/CodeGen/AArch64/madd-combiner.ll +++ /dev/null @@ -1,20 +0,0 @@ -; RUN: llc -mtriple=aarch64-apple-darwin -verify-machineinstrs < %s | FileCheck %s - -; Test that we use the correct register class. -define i32 @mul_add_imm(i32 %a, i32 %b) { -; CHECK-LABEL: mul_add_imm -; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4 -; CHECK-NEXT: madd {{w[0-9]+}}, w0, w1, [[REG]] - %1 = mul i32 %a, %b - %2 = add i32 %1, 4 - ret i32 %2 -} - -define i32 @mul_sub_imm1(i32 %a, i32 %b) { -; CHECK-LABEL: mul_sub_imm1 -; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4 -; CHECK-NEXT: msub {{w[0-9]+}}, w0, w1, [[REG]] - %1 = mul i32 %a, %b - %2 = sub i32 4, %1 - ret i32 %2 -}