// Patch summary (PTX backend): name and declare virtual registers per register class.
//
// PTXAsmPrinter.cpp
//   * EmitFunctionBodyStart: the removed loop emitted one "\t.reg .<type> <name>;" line for each
//     register in [localVarRegBegin, localVarRegEnd). The added code instead appends, for each of
//     the six classes (pred, b16, b32, b64, f32, f64), one "\t.reg .<type> <prefix><N>;" line to
//     regDefs, where N = MFI->getNumRegistersForClass(<class>), skipping classes with N == 0, and
//     emits the accumulated string with a single EmitRawText call.
//   * printOperand (MO_Register case) and printPredicateOperand now print
//     MFI->getRegisterName(reg) instead of getRegisterName(reg).
//   * The printReturnOperand and "const Constant *C" -/+ pairs read identically here; presumably
//     whitespace-only changes -- TODO confirm against the original commit.
//
// PTXInstrInfo.cpp
//   * copyPhysReg: asserts MRI.getRegClass(SrcReg) == MRI.getRegClass(DstReg), and selects the
//     map entry with map[i].cls == MRI.getRegClass(SrcReg) rather than
//     map[i].cls->contains(DstReg, SrcReg).
//
// PTXMFInfoExtract.cpp
//   * runOnMachineFunction: for every index i below MRI.getNumVirtRegs(), converts i with
//     TargetRegisterInfo::index2VirtReg, looks up its class, and calls MFI->addVirtualRegister.
//
// PTXMachineFunctionInfo.h
//   * Adds usedRegs (class -> register list) and regNames (register -> name); the constructor
//     seeds usedRegs with an empty list for each of the six classes.
//   * addVirtualRegister: appends Reg to its class list and stores the name
//     <class prefix> + (list size - 1); any other class hits llvm_unreachable.
//   * getRegisterName: returns the stored name; an unknown register hits llvm_unreachable.
//   * getNumRegistersForClass: returns usedRegs.lookup(TRC).size().
//
// NOTE(review): template argument lists appear to be missing in this text (e.g. "MF->getInfo()",
//   "dyn_cast(C)", "typedef std::vector RegisterList", "typedef DenseMap RegisterMap") --
//   presumably stripped during extraction; verify against the original commit before applying.
// NOTE(review): getRegisterName queries regNames twice (count, then lookup).
// NOTE(review): if DenseMap::lookup returns by value, getNumRegistersForClass copies the whole
//   register list just to read its size -- confirm and consider find() instead.
diff --git a/llvm/lib/Target/PTX/PTXAsmPrinter.cpp b/llvm/lib/Target/PTX/PTXAsmPrinter.cpp index 97bfed07958c..f936d4bb3c4e 100644 --- a/llvm/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/llvm/lib/Target/PTX/PTXAsmPrinter.cpp @@ -16,6 +16,7 @@ #include "PTX.h" #include "PTXMachineFunctionInfo.h" +#include "PTXRegisterInfo.h" #include "PTXTargetMachine.h" #include "llvm/DerivedTypes.h" #include "llvm/Module.h" @@ -67,7 +68,7 @@ public: void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, const char *Modifier = 0); void printReturnOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, - const char *Modifier = 0); + const char *Modifier = 0); void printPredicateOperand(const MachineInstr *MI, raw_ostream &O); void printCall(const MachineInstr *MI, raw_ostream &O); @@ -217,19 +218,61 @@ void PTXAsmPrinter::EmitFunctionBodyStart() { const PTXMachineFunctionInfo *MFI = MF->getInfo(); - // Print local variable definition - for (PTXMachineFunctionInfo::reg_iterator - i = MFI->localVarRegBegin(), e = MFI->localVarRegEnd(); i != e; ++ i) { - unsigned reg = *i; + // Print register definitions + std::string regDefs; + unsigned numRegs; - std::string def = "\t.reg ."; - def += getRegisterTypeName(reg); - def += ' '; - def += getRegisterName(reg); - def += ';'; - OutStreamer.EmitRawText(Twine(def)); + // pred + numRegs = MFI->getNumRegistersForClass(PTX::RegPredRegisterClass); + if(numRegs > 0) { + regDefs += "\t.reg .pred %p<"; + regDefs += utostr(numRegs); + regDefs += ">;\n"; } + // i16 + numRegs = MFI->getNumRegistersForClass(PTX::RegI16RegisterClass); + if(numRegs > 0) { + regDefs += "\t.reg .b16 %rh<"; + regDefs += utostr(numRegs); + regDefs += ">;\n"; + } + + // i32 + numRegs = MFI->getNumRegistersForClass(PTX::RegI32RegisterClass); + if(numRegs > 0) { + regDefs += "\t.reg .b32 %r<"; + regDefs += utostr(numRegs); + regDefs += ">;\n"; + } + + // i64 + numRegs = MFI->getNumRegistersForClass(PTX::RegI64RegisterClass); + if(numRegs > 0) { + regDefs += 
"\t.reg .b64 %rd<"; + regDefs += utostr(numRegs); + regDefs += ">;\n"; + } + + // f32 + numRegs = MFI->getNumRegistersForClass(PTX::RegF32RegisterClass); + if(numRegs > 0) { + regDefs += "\t.reg .f32 %f<"; + regDefs += utostr(numRegs); + regDefs += ">;\n"; + } + + // f64 + numRegs = MFI->getNumRegistersForClass(PTX::RegF64RegisterClass); + if(numRegs > 0) { + regDefs += "\t.reg .f64 %fd<"; + regDefs += utostr(numRegs); + regDefs += ">;\n"; + } + + OutStreamer.EmitRawText(Twine(regDefs)); + + const MachineFrameInfo* FrameInfo = MF->getFrameInfo(); DEBUG(dbgs() << "Have " << FrameInfo->getNumObjects() << " frame object(s)\n"); @@ -332,6 +375,7 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) { void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS) { const MachineOperand &MO = MI->getOperand(opNum); + const PTXMachineFunctionInfo *MFI = MF->getInfo(); switch (MO.getType()) { default: @@ -347,7 +391,7 @@ void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, OS << *MO.getMBB()->getSymbol(); break; case MachineOperand::MO_Register: - OS << getRegisterName(MO.getReg()); + OS << MFI->getRegisterName(MO.getReg()); break; case MachineOperand::MO_FPImmediate: APInt constFP = MO.getFPImm()->getValueAPF().bitcastToAPInt(); @@ -466,7 +510,7 @@ void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) { if (gv->hasInitializer()) { - const Constant *C = gv->getInitializer(); + const Constant *C = gv->getInitializer(); if (const ConstantArray *CA = dyn_cast(C)) { decl += " = {"; @@ -577,6 +621,7 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) { unsigned reg = MI->getOperand(i).getReg(); int predOp = MI->getOperand(i+1).getImm(); + const PTXMachineFunctionInfo *MFI = MF->getInfo(); DEBUG(dbgs() << "predicate: (" << reg << ", " << predOp << ")\n"); @@ -584,7 +629,7 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) { O << '@'; if (predOp == PTX::PRED_NEGATE) O << '!'; - O << 
getRegisterName(reg); + O << MFI->getRegisterName(reg); } } diff --git a/llvm/lib/Target/PTX/PTXInstrInfo.cpp b/llvm/lib/Target/PTX/PTXInstrInfo.cpp index 3ea75b277379..4d4bde408815 100644 --- a/llvm/lib/Target/PTX/PTXInstrInfo.cpp +++ b/llvm/lib/Target/PTX/PTXInstrInfo.cpp @@ -16,6 +16,7 @@ #include "PTX.h" #include "PTXInstrInfo.h" #include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/Support/Debug.h" @@ -47,8 +48,13 @@ void PTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator I, DebugLoc DL, unsigned DstReg, unsigned SrcReg, bool KillSrc) const { - for (int i = 0, e = sizeof(map)/sizeof(map[0]); i != e; ++ i) { - if (map[i].cls->contains(DstReg, SrcReg)) { + + const MachineRegisterInfo& MRI = MBB.getParent()->getRegInfo(); + assert(MRI.getRegClass(SrcReg) == MRI.getRegClass(DstReg) && + "Invalid register copy between two register classes"); + + for (int i = 0, e = sizeof(map)/sizeof(map[0]); i != e; ++i) { + if (map[i].cls == MRI.getRegClass(SrcReg)) { const MCInstrDesc &MCID = get(map[i].opcode); MachineInstr *MI = BuildMI(MBB, I, DL, MCID, DstReg). 
addReg(SrcReg, getKillRegState(KillSrc)); diff --git a/llvm/lib/Target/PTX/PTXMFInfoExtract.cpp b/llvm/lib/Target/PTX/PTXMFInfoExtract.cpp index 6fe9e6c3f657..0a41520fcc20 100644 --- a/llvm/lib/Target/PTX/PTXMFInfoExtract.cpp +++ b/llvm/lib/Target/PTX/PTXMFInfoExtract.cpp @@ -83,6 +83,13 @@ bool PTXMFInfoExtract::runOnMachineFunction(MachineFunction &MF) { i != e; ++i) dbgs() << "Local Var Reg: " << *i << "\n";); + // Generate list of all virtual registers used in this function + for (unsigned i = 0; i < MRI.getNumVirtRegs(); ++i) { + unsigned Reg = TargetRegisterInfo::index2VirtReg(i); + const TargetRegisterClass *TRC = MRI.getRegClass(Reg); + MFI->addVirtualRegister(TRC, Reg); + } + return false; } diff --git a/llvm/lib/Target/PTX/PTXMachineFunctionInfo.h b/llvm/lib/Target/PTX/PTXMachineFunctionInfo.h index a3b0f324feb8..16e5e7ba7fa6 100644 --- a/llvm/lib/Target/PTX/PTXMachineFunctionInfo.h +++ b/llvm/lib/Target/PTX/PTXMachineFunctionInfo.h @@ -15,7 +15,10 @@ #define PTX_MACHINE_FUNCTION_INFO_H #include "PTX.h" +#include "PTXRegisterInfo.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/CodeGen/MachineFunction.h" namespace llvm { @@ -30,11 +33,25 @@ private: std::vector call_params; bool _isDoneAddArg; + typedef std::vector RegisterList; + typedef DenseMap RegisterMap; + typedef DenseMap RegisterNameMap; + + RegisterMap usedRegs; + RegisterNameMap regNames; + public: PTXMachineFunctionInfo(MachineFunction &MF) : is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) { reg_arg.reserve(8); reg_local_var.reserve(32); + + usedRegs[PTX::RegPredRegisterClass] = RegisterList(); + usedRegs[PTX::RegI16RegisterClass] = RegisterList(); + usedRegs[PTX::RegI32RegisterClass] = RegisterList(); + usedRegs[PTX::RegI64RegisterClass] = RegisterList(); + usedRegs[PTX::RegF32RegisterClass] = RegisterList(); + usedRegs[PTX::RegF64RegisterClass] = RegisterList(); } void setKernel(bool _is_kernel=true) { 
is_kernel = _is_kernel; } @@ -94,6 +111,42 @@ public: return std::find(reg_local_var.begin(), reg_local_var.end(), reg) != reg_local_var.end(); } + + void addVirtualRegister(const TargetRegisterClass *TRC, unsigned Reg) { + usedRegs[TRC].push_back(Reg); + + std::string name; + + if (TRC == PTX::RegPredRegisterClass) + name = "%p"; + else if (TRC == PTX::RegI16RegisterClass) + name = "%rh"; + else if (TRC == PTX::RegI32RegisterClass) + name = "%r"; + else if (TRC == PTX::RegI64RegisterClass) + name = "%rd"; + else if (TRC == PTX::RegF32RegisterClass) + name = "%f"; + else if (TRC == PTX::RegF64RegisterClass) + name = "%fd"; + else + llvm_unreachable("Invalid register class"); + + name += utostr(usedRegs[TRC].size() - 1); + regNames[Reg] = name; + } + + std::string getRegisterName(unsigned Reg) const { + if (regNames.count(Reg)) + return regNames.lookup(Reg); + else + llvm_unreachable("Register not in register name map"); + } + + unsigned getNumRegistersForClass(const TargetRegisterClass *TRC) const { + return usedRegs.lookup(TRC).size(); + } + }; // class PTXMachineFunctionInfo } // namespace llvm