diff --git a/llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h index 76f11d3187e0..3a59e7f18b6a 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h @@ -18,7 +18,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/CodeGen/GlobalISel/RegisterBank.h" -#include "llvm/CodeGen/GlobalISel/Types.h" +#include "llvm/CodeGen/MachineValueType.h" // For SimpleValueType. #include "llvm/Support/ErrorHandling.h" #include @@ -158,6 +158,9 @@ protected: /// Total number of register banks. unsigned NumRegBanks; + /// Mapping from MVT::SimpleValueType to register banks. + std::unique_ptr VTToRegBank; + /// Create a RegisterBankInfo that can accomodate up to \p NumRegBanks /// RegisterBank instances. /// @@ -182,13 +185,21 @@ protected: /// \pre \p ID < NumRegBanks. void createRegisterBank(unsigned ID, const char *Name); - /// Add \p RCId to the set of register class that the register bank - /// identified \p ID covers. + /// Add \p RCId to the set of register class that the register bank, + /// identified \p ID, covers. /// This method transitively adds all the sub classes and the subreg-classes /// of \p RCId to the set of covered register classes. /// It also adjusts the size of the register bank to reflect the maximal /// size of a value that can be hold into that register bank. /// + /// If \p AddTypeMapping is true, this method also records what types can + /// be mapped to \p ID. Although this done by default, targets may want to + /// disable it, espicially if a given type may be mapped on different + /// register bank. Indeed, in such case, this method only records the + /// last register bank where the type matches. + /// This information is only used to provide default mapping + /// (see getInstrMappingImpl). + /// /// \note This method does *not* add the super classes of \p RCId. /// The rationale is if \p ID covers the registers of \p RCId, that /// does not necessarily mean that \p ID covers the set of registers @@ -199,7 +210,8 @@ protected: /// /// \todo TableGen should just generate the BitSet vector for us. void addRegBankCoverage(unsigned ID, unsigned RCId, - const TargetRegisterInfo &TRI); + const TargetRegisterInfo &TRI, + bool AddTypeMapping = true); /// Get the register bank identified by \p ID. RegisterBank &getRegBank(unsigned ID) { @@ -207,6 +219,28 @@ protected: return RegBanks[ID]; } + /// Get the register bank that has been recorded to cover \p SVT. + const RegisterBank *getRegBankForType(MVT::SimpleValueType SVT) const { + if (!VTToRegBank) + return nullptr; + assert(SVT < MVT::SimpleValueType::LAST_VALUETYPE && "Out-of-bound access"); + return VTToRegBank.get()[SVT]; + } + + /// Add \p SVT to the type that \p RegBank covers. + /// + /// \post If \p SVT was covered by another register bank before that call, + /// then this information is gone. + /// \post getRegBankForType(SVT) == &RegBank + void recordRegBankForType(const RegisterBank &RegBank, + MVT::SimpleValueType SVT) { + if (!VTToRegBank) + VTToRegBank.reset( + new const RegisterBank *[MVT::SimpleValueType::LAST_VALUETYPE]); + assert(SVT < MVT::SimpleValueType::LAST_VALUETYPE && "Out-of-bound access"); + VTToRegBank.get()[SVT] = &RegBank; + } + /// Try to get the mapping of \p MI. /// See getInstrMapping for more details on what a mapping represents. /// diff --git a/llvm/lib/CodeGen/GlobalISel/RegisterBankInfo.cpp b/llvm/lib/CodeGen/GlobalISel/RegisterBankInfo.cpp index 72064e1d6aab..f1d997eaa707 100644 --- a/llvm/lib/CodeGen/GlobalISel/RegisterBankInfo.cpp +++ b/llvm/lib/CodeGen/GlobalISel/RegisterBankInfo.cpp @@ -13,6 +13,7 @@ #include "llvm/CodeGen/GlobalISel/RegisterBank.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/CodeGen/GlobalISel/RegisterBankInfo.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineFunction.h" @@ -86,7 +87,8 @@ void RegisterBankInfo::createRegisterBank(unsigned ID, const char *Name) { } void RegisterBankInfo::addRegBankCoverage(unsigned ID, unsigned RCId, - const TargetRegisterInfo &TRI) { + const TargetRegisterInfo &TRI, + bool AddTypeMapping) { RegisterBank &RB = getRegBank(ID); unsigned NbOfRegClasses = TRI.getNumRegClasses(); @@ -118,6 +120,13 @@ void RegisterBankInfo::addRegBankCoverage(unsigned ID, unsigned RCId, // Remember the biggest size in bits. MaxSize = std::max(MaxSize, CurRC.getSize() * 8); + // If we have been asked to record the type supported by this + // register bank, do it now. + if (AddTypeMapping) + for (MVT::SimpleValueType SVT : + make_range(CurRC.vt_begin(), CurRC.vt_end())) + recordRegBankForType(getRegBank(ID), SVT); + // Walk through all sub register classes and push them into the worklist. bool First = true; for (BitMaskClassIterator It(CurRC.getSubClassMask(), TRI); It.isValid();