diff --git a/llvm/lib/Target/X86/X86EvexToVex.cpp b/llvm/lib/Target/X86/X86EvexToVex.cpp index 97f843fa24eb..c7a013a0b17a 100644 --- a/llvm/lib/Target/X86/X86EvexToVex.cpp +++ b/llvm/lib/Target/X86/X86EvexToVex.cpp @@ -151,24 +151,6 @@ static bool performCustomAdjustments(MachineInstr &MI, unsigned NewOpc, (void)NewOpc; unsigned Opc = MI.getOpcode(); switch (Opc) { - case X86::VPDPBUSDSZ256m: - case X86::VPDPBUSDSZ256r: - case X86::VPDPBUSDSZ128m: - case X86::VPDPBUSDSZ128r: - case X86::VPDPBUSDZ256m: - case X86::VPDPBUSDZ256r: - case X86::VPDPBUSDZ128m: - case X86::VPDPBUSDZ128r: - case X86::VPDPWSSDSZ256m: - case X86::VPDPWSSDSZ256r: - case X86::VPDPWSSDSZ128m: - case X86::VPDPWSSDSZ128r: - case X86::VPDPWSSDZ256m: - case X86::VPDPWSSDZ256r: - case X86::VPDPWSSDZ128m: - case X86::VPDPWSSDZ128r: - // These can only VEX convert if AVXVNNI is enabled. - return ST->hasAVXVNNI(); case X86::VALIGNDZ128rri: case X86::VALIGNDZ128rmi: case X86::VALIGNQZ128rri: @@ -280,6 +262,9 @@ bool EvexToVexInstPass::CompressEvexToVexImpl(MachineInstr &MI) const { if (usesExtendedRegister(MI)) return false; + if (!CheckVEXInstPredicate(MI, ST)) + return false; + if (!performCustomAdjustments(MI, NewOpc, ST)) return false; diff --git a/llvm/lib/Target/X86/X86InstrFormats.td b/llvm/lib/Target/X86/X86InstrFormats.td index 686b19fc0a6c..dba13720cbd2 100644 --- a/llvm/lib/Target/X86/X86InstrFormats.td +++ b/llvm/lib/Target/X86/X86InstrFormats.td @@ -352,7 +352,8 @@ class X86Inst opcod, Format f, ImmType i, dag outs, dag ins, bit isMemoryFoldable = 1; // Is it allowed to memory fold/unfold this instruction? bit notEVEX2VEXConvertible = 0; // Prevent EVEX->VEX conversion. bit ExplicitVEXPrefix = 0; // Force the instruction to use VEX encoding. - + // Force to check predicate before compress EVEX to VEX encoding. + bit checkVEXPredicate = 0; // TSFlags layout should be kept in sync with X86BaseInfo.h. let TSFlags{6-0} = FormBits; let TSFlags{8-7} = OpSizeBits; diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td index ab84bae11583..662fa65c9566 100644 --- a/llvm/lib/Target/X86/X86InstrSSE.td +++ b/llvm/lib/Target/X86/X86InstrSSE.td @@ -7166,7 +7166,8 @@ defm VMASKMOVPD : avx_movmask_rm<0x2D, 0x2F, "vmaskmovpd", //===----------------------------------------------------------------------===// // AVX_VNNI //===----------------------------------------------------------------------===// -let Predicates = [HasAVXVNNI, NoVLX_Or_NoVNNI], Constraints = "$src1 = $dst" in +let Predicates = [HasAVXVNNI, NoVLX_Or_NoVNNI], Constraints = "$src1 = $dst", + ExplicitVEXPrefix = 1, checkVEXPredicate = 1 in multiclass avx_vnni_rm opc, string OpcodeStr, SDNode OpNode, bit IsCommutable> { let isCommutable = IsCommutable in @@ -7200,10 +7201,10 @@ multiclass avx_vnni_rm opc, string OpcodeStr, SDNode OpNode, VEX_4V, VEX_L, Sched<[SchedWriteVecIMul.XMM]>; } -defm VPDPBUSD : avx_vnni_rm<0x50, "vpdpbusd", X86Vpdpbusd, 0>, ExplicitVEXPrefix; -defm VPDPBUSDS : avx_vnni_rm<0x51, "vpdpbusds", X86Vpdpbusds, 0>, ExplicitVEXPrefix; -defm VPDPWSSD : avx_vnni_rm<0x52, "vpdpwssd", X86Vpdpwssd, 1>, ExplicitVEXPrefix; -defm VPDPWSSDS : avx_vnni_rm<0x53, "vpdpwssds", X86Vpdpwssds, 1>, ExplicitVEXPrefix; +defm VPDPBUSD : avx_vnni_rm<0x50, "vpdpbusd", X86Vpdpbusd, 0>; +defm VPDPBUSDS : avx_vnni_rm<0x51, "vpdpbusds", X86Vpdpbusds, 0>; +defm VPDPWSSD : avx_vnni_rm<0x52, "vpdpwssd", X86Vpdpwssd, 1>; +defm VPDPWSSDS : avx_vnni_rm<0x53, "vpdpwssds", X86Vpdpwssds, 1>; def X86vpmaddwd_su : PatFrag<(ops node:$lhs, node:$rhs), (X86vpmaddwd node:$lhs, node:$rhs), [{ diff --git a/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp b/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp index 6dc7e31e0dab..009dc036cf97 100644 --- a/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp +++ b/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp @@ -30,10 +30,13 @@ class X86EVEX2VEXTablesEmitter { std::map> VEXInsts; typedef std::pair Entry; + typedef std::pair Predicate; // Represent both compress tables std::vector EVEX2VEX128; std::vector EVEX2VEX256; + // Represent predicates of VEX instructions. + std::vector EVEX2VEXPredicates; public: X86EVEX2VEXTablesEmitter(RecordKeeper &R) : Records(R), Target(R) {} @@ -45,6 +48,9 @@ private: // Prints the given table as a C++ array of type // X86EvexToVexCompressTableEntry void printTable(const std::vector &Table, raw_ostream &OS); + // Prints function which checks target feature specific predicate. + void printCheckPredicate(const std::vector &Predicates, + raw_ostream &OS); }; void X86EVEX2VEXTablesEmitter::printTable(const std::vector &Table, @@ -67,6 +73,19 @@ void X86EVEX2VEXTablesEmitter::printTable(const std::vector &Table, OS << "};\n\n"; } +void X86EVEX2VEXTablesEmitter::printCheckPredicate( + const std::vector &Predicates, raw_ostream &OS) { + OS << "static bool CheckVEXInstPredicate" + << "(MachineInstr &MI, const X86Subtarget *Subtarget) {\n" + << " unsigned Opc = MI.getOpcode();\n" + << " switch (Opc) {\n" + << " default: return true;\n"; + for (auto Pair : Predicates) + OS << " case X86::" << Pair.first << ": return " << Pair.second << ";\n"; + OS << " }\n" + << "}\n\n"; +} + // Return true if the 2 BitsInits are equal // Calculates the integer value residing BitsInit object static inline uint64_t getValueFromBitsInit(const BitsInit *B) { @@ -169,6 +188,18 @@ private: }; void X86EVEX2VEXTablesEmitter::run(raw_ostream &OS) { + auto getPredicates = [&](const CodeGenInstruction *Inst) { + std::vector PredicatesRecords = + Inst->TheDef->getValueAsListOfDefs("Predicates"); + // Currently we only do AVX related checks and assume each instruction + // has one and only one AVX related predicates. + for (unsigned i = 0, e = PredicatesRecords.size(); i != e; ++i) + if (PredicatesRecords[i]->getName().startswith("HasAVX")) + return PredicatesRecords[i]->getValueAsString("CondString"); + llvm_unreachable( + "Instruction with checkPredicate set must have one predicate!"); + }; + emitSourceFileHeader("X86 EVEX2VEX tables", OS); ArrayRef NumberedInstructions = @@ -222,11 +253,18 @@ void X86EVEX2VEXTablesEmitter::run(raw_ostream &OS) { EVEX2VEX256.push_back(std::make_pair(EVEXInst, VEXInst)); // {0,1} else EVEX2VEX128.push_back(std::make_pair(EVEXInst, VEXInst)); // {0,0} + + // Adding predicate check to EVEX2VEXPredicates table when needed. + if (VEXInst->TheDef->getValueAsBit("checkVEXPredicate")) + EVEX2VEXPredicates.push_back( + std::make_pair(EVEXInst->TheDef->getName(), getPredicates(VEXInst))); } // Print both tables printTable(EVEX2VEX128, OS); printTable(EVEX2VEX256, OS); + // Print CheckVEXInstPredicate function. + printCheckPredicate(EVEX2VEXPredicates, OS); } }