[TableGen] [AMDGPU] Add !sub operator for subtraction

Use it in the AMDGPU target to eliminate !add(value1, !mul(value2, -1))

Differential Revision: https://reviews.llvm.org/D90107
This commit is contained in:
Paul C. Anagnostopoulos 2020-10-23 17:02:39 -04:00
parent b9c21d43bb
commit 9d72065cf6
9 changed files with 81 additions and 53 deletions

View File

@ -214,8 +214,8 @@ TableGen provides "bang operators" that have a wide variety of uses:
: !getdagop !gt !head !if !isa : !getdagop !gt !head !if !isa
: !le !listconcat !listsplat !lt !mul : !le !listconcat !listsplat !lt !mul
: !ne !not !or !setdagop !shl : !ne !not !or !setdagop !shl
: !size !sra !srl !strconcat !subst : !size !sra !srl !strconcat !sub
: !tail !xor : !subst !tail !xor
The ``!cond`` operator has a slightly different The ``!cond`` operator has a slightly different
syntax compared to other bang operators, so it is defined separately: syntax compared to other bang operators, so it is defined separately:
@ -1663,6 +1663,9 @@ and non-0 as true.
are not strings, in which are not strings, in which
case an implicit ``!cast<string>`` is done on those operands. case an implicit ``!cast<string>`` is done on those operands.
``!sub(``\ *a*\ ``,`` *b*\ ``)``
This operator subtracts *b* from *a* and produces the arithmetic difference.
``!subst(``\ *target*\ ``,`` *repl*\ ``,`` *value*\ ``)`` ``!subst(``\ *target*\ ``,`` *repl*\ ``,`` *value*\ ``)``
This operator replaces all occurrences of the *target* in the *value* with This operator replaces all occurrences of the *target* in the *value* with
the *repl* and produces the resulting value. The *value* can the *repl* and produces the resulting value. The *value* can

View File

@ -808,7 +808,7 @@ public:
/// !op (X, Y) - Combine two inits. /// !op (X, Y) - Combine two inits.
class BinOpInit : public OpInit, public FoldingSetNode { class BinOpInit : public OpInit, public FoldingSetNode {
public: public:
enum BinaryOp : uint8_t { ADD, MUL, AND, OR, XOR, SHL, SRA, SRL, LISTCONCAT, enum BinaryOp : uint8_t { ADD, SUB, MUL, AND, OR, XOR, SHL, SRA, SRL, LISTCONCAT,
LISTSPLAT, STRCONCAT, CONCAT, EQ, NE, LE, LT, GE, LISTSPLAT, STRCONCAT, CONCAT, EQ, NE, LE, LT, GE,
GT, SETDAGOP }; GT, SETDAGOP };

View File

@ -1024,6 +1024,7 @@ Init *BinOpInit::Fold(Record *CurRec) const {
break; break;
} }
case ADD: case ADD:
case SUB:
case MUL: case MUL:
case AND: case AND:
case OR: case OR:
@ -1040,9 +1041,10 @@ Init *BinOpInit::Fold(Record *CurRec) const {
int64_t Result; int64_t Result;
switch (getOpcode()) { switch (getOpcode()) {
default: llvm_unreachable("Bad opcode!"); default: llvm_unreachable("Bad opcode!");
case ADD: Result = LHSv + RHSv; break; case ADD: Result = LHSv + RHSv; break;
case MUL: Result = LHSv * RHSv; break; case SUB: Result = LHSv - RHSv; break;
case AND: Result = LHSv & RHSv; break; case MUL: Result = LHSv * RHSv; break;
case AND: Result = LHSv & RHSv; break;
case OR: Result = LHSv | RHSv; break; case OR: Result = LHSv | RHSv; break;
case XOR: Result = LHSv ^ RHSv; break; case XOR: Result = LHSv ^ RHSv; break;
case SHL: Result = (uint64_t)LHSv << (uint64_t)RHSv; break; case SHL: Result = (uint64_t)LHSv << (uint64_t)RHSv; break;
@ -1072,6 +1074,7 @@ std::string BinOpInit::getAsString() const {
switch (getOpcode()) { switch (getOpcode()) {
case CONCAT: Result = "!con"; break; case CONCAT: Result = "!con"; break;
case ADD: Result = "!add"; break; case ADD: Result = "!add"; break;
case SUB: Result = "!sub"; break;
case MUL: Result = "!mul"; break; case MUL: Result = "!mul"; break;
case AND: Result = "!and"; break; case AND: Result = "!and"; break;
case OR: Result = "!or"; break; case OR: Result = "!or"; break;

View File

@ -562,6 +562,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
.Case("con", tgtok::XConcat) .Case("con", tgtok::XConcat)
.Case("dag", tgtok::XDag) .Case("dag", tgtok::XDag)
.Case("add", tgtok::XADD) .Case("add", tgtok::XADD)
.Case("sub", tgtok::XSUB)
.Case("mul", tgtok::XMUL) .Case("mul", tgtok::XMUL)
.Case("not", tgtok::XNOT) .Case("not", tgtok::XNOT)
.Case("and", tgtok::XAND) .Case("and", tgtok::XAND)

View File

@ -51,7 +51,7 @@ namespace tgtok {
MultiClass, String, Defset, Defvar, If, Then, ElseKW, MultiClass, String, Defset, Defvar, If, Then, ElseKW,
// !keywords. // !keywords.
XConcat, XADD, XMUL, XNOT, XAND, XOR, XXOR, XSRA, XSRL, XSHL, XConcat, XADD, XSUB, XMUL, XNOT, XAND, XOR, XXOR, XSRA, XSRL, XSHL,
XListConcat, XListSplat, XStrConcat, XCast, XSubst, XForEach, XFoldl, XListConcat, XListSplat, XStrConcat, XCast, XSubst, XForEach, XFoldl,
XHead, XTail, XSize, XEmpty, XIf, XCond, XEq, XIsA, XDag, XNe, XLe, XHead, XTail, XSize, XEmpty, XIf, XCond, XEq, XIsA, XDag, XNe, XLe,
XLt, XGe, XGt, XSetDagOp, XGetDagOp, XLt, XGe, XGt, XSetDagOp, XGetDagOp,

View File

@ -1075,6 +1075,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XConcat: case tgtok::XConcat:
case tgtok::XADD: case tgtok::XADD:
case tgtok::XSUB:
case tgtok::XMUL: case tgtok::XMUL:
case tgtok::XAND: case tgtok::XAND:
case tgtok::XOR: case tgtok::XOR:
@ -1101,6 +1102,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
default: llvm_unreachable("Unhandled code!"); default: llvm_unreachable("Unhandled code!");
case tgtok::XConcat: Code = BinOpInit::CONCAT; break; case tgtok::XConcat: Code = BinOpInit::CONCAT; break;
case tgtok::XADD: Code = BinOpInit::ADD; break; case tgtok::XADD: Code = BinOpInit::ADD; break;
case tgtok::XSUB: Code = BinOpInit::SUB; break;
case tgtok::XMUL: Code = BinOpInit::MUL; break; case tgtok::XMUL: Code = BinOpInit::MUL; break;
case tgtok::XAND: Code = BinOpInit::AND; break; case tgtok::XAND: Code = BinOpInit::AND; break;
case tgtok::XOR: Code = BinOpInit::OR; break; case tgtok::XOR: Code = BinOpInit::OR; break;
@ -1137,6 +1139,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XSRL: case tgtok::XSRL:
case tgtok::XSHL: case tgtok::XSHL:
case tgtok::XADD: case tgtok::XADD:
case tgtok::XSUB:
case tgtok::XMUL: case tgtok::XMUL:
Type = IntRecTy::get(); Type = IntRecTy::get();
ArgType = IntRecTy::get(); ArgType = IntRecTy::get();
@ -1249,10 +1252,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
ListType->getAsString() + "'"); ListType->getAsString() + "'");
return nullptr; return nullptr;
} }
if (Code != BinOpInit::ADD && Code != BinOpInit::AND && if (Code != BinOpInit::ADD && Code != BinOpInit::SUB &&
Code != BinOpInit::OR && Code != BinOpInit::XOR && Code != BinOpInit::AND && Code != BinOpInit::OR &&
Code != BinOpInit::SRA && Code != BinOpInit::SRL && Code != BinOpInit::XOR && Code != BinOpInit::SRA &&
Code != BinOpInit::SHL && Code != BinOpInit::MUL) Code != BinOpInit::SRL && Code != BinOpInit::SHL &&
Code != BinOpInit::MUL)
ArgType = Resolved; ArgType = Resolved;
} }
@ -1799,6 +1803,7 @@ Init *TGParser::ParseOperationCond(Record *CurRec, RecTy *ItemType) {
/// SimpleValue ::= '(' IDValue DagArgList ')' /// SimpleValue ::= '(' IDValue DagArgList ')'
/// SimpleValue ::= CONCATTOK '(' Value ',' Value ')' /// SimpleValue ::= CONCATTOK '(' Value ',' Value ')'
/// SimpleValue ::= ADDTOK '(' Value ',' Value ')' /// SimpleValue ::= ADDTOK '(' Value ',' Value ')'
/// SimpleValue ::= SUBTOK '(' Value ',' Value ')'
/// SimpleValue ::= SHLTOK '(' Value ',' Value ')' /// SimpleValue ::= SHLTOK '(' Value ',' Value ')'
/// SimpleValue ::= SRATOK '(' Value ',' Value ')' /// SimpleValue ::= SRATOK '(' Value ',' Value ')'
/// SimpleValue ::= SRLTOK '(' Value ',' Value ')' /// SimpleValue ::= SRLTOK '(' Value ',' Value ')'
@ -2094,6 +2099,7 @@ Init *TGParser::ParseSimpleValue(Record *CurRec, RecTy *ItemType,
case tgtok::XConcat: case tgtok::XConcat:
case tgtok::XDag: case tgtok::XDag:
case tgtok::XADD: case tgtok::XADD:
case tgtok::XSUB:
case tgtok::XMUL: case tgtok::XMUL:
case tgtok::XNOT: case tgtok::XNOT:
case tgtok::XAND: case tgtok::XAND:

View File

@ -27,17 +27,17 @@ let Namespace = "AMDGPU" in {
def lo16 : SubRegIndex<16, 0>; def lo16 : SubRegIndex<16, 0>;
def hi16 : SubRegIndex<16, 16>; def hi16 : SubRegIndex<16, 16>;
foreach Index = 0-31 in { foreach Index = 0...31 in {
def sub#Index : SubRegIndex<32, !shl(Index, 5)>; def sub#Index : SubRegIndex<32, !shl(Index, 5)>;
} }
foreach Index = 1-31 in { foreach Index = 1...31 in {
def sub#Index#_lo16 : ComposedSubRegIndex<!cast<SubRegIndex>(sub#Index), lo16>; def sub#Index#_lo16 : ComposedSubRegIndex<!cast<SubRegIndex>(sub#Index), lo16>;
def sub#Index#_hi16 : ComposedSubRegIndex<!cast<SubRegIndex>(sub#Index), hi16>; def sub#Index#_hi16 : ComposedSubRegIndex<!cast<SubRegIndex>(sub#Index), hi16>;
} }
foreach Size = {2-6,8,16} in { foreach Size = {2...6,8,16} in {
foreach Index = Indexes<!add(33, !mul(Size, -1))>.slice in { foreach Index = Indexes<!sub(33, Size)>.slice in {
def !foldl("", Indexes<Size>.slice, acc, cur, def !foldl("", Indexes<Size>.slice, acc, cur,
!strconcat(acc#!if(!eq(acc,""),"","_"), "sub"#!add(cur, Index))) : !strconcat(acc#!if(!eq(acc,""),"","_"), "sub"#!add(cur, Index))) :
SubRegIndex<!mul(Size, 32), !shl(Index, 5)> { SubRegIndex<!mul(Size, 32), !shl(Index, 5)> {
@ -89,7 +89,7 @@ class getSubRegs<int size> {
class RegSeqNames<int last_reg, int stride, int size, string prefix, class RegSeqNames<int last_reg, int stride, int size, string prefix,
int start = 0> { int start = 0> {
int next = !add(start, stride); int next = !add(start, stride);
int end_reg = !add(!add(start, size), -1); int end_reg = !add(start, size, -1);
list<string> ret = list<string> ret =
!if(!le(end_reg, last_reg), !if(!le(end_reg, last_reg),
!listconcat([prefix # "[" # start # ":" # end_reg # "]"], !listconcat([prefix # "[" # start # ":" # end_reg # "]"],
@ -102,7 +102,7 @@ class RegSeqDags<RegisterClass RC, int last_reg, int stride, int size,
int start = 0> { int start = 0> {
dag trunc_rc = (trunc RC, dag trunc_rc = (trunc RC,
!if(!and(!eq(stride, 1), !eq(start, 0)), !if(!and(!eq(stride, 1), !eq(start, 0)),
!add(!add(last_reg, 2), !mul(size, -1)), !sub(!add(last_reg, 2), size),
!add(last_reg, 1))); !add(last_reg, 1)));
list<dag> ret = list<dag> ret =
!if(!lt(start, size), !if(!lt(start, size),
@ -247,7 +247,7 @@ def TMA : RegisterWithSubRegs<"tma", [TMA_LO, TMA_HI]> {
let HWEncoding = 110; let HWEncoding = 110;
} }
foreach Index = 0-15 in { foreach Index = 0...15 in {
defm TTMP#Index#_vi : SIRegLoHi16<"ttmp"#Index, !add(112, Index)>; defm TTMP#Index#_vi : SIRegLoHi16<"ttmp"#Index, !add(112, Index)>;
defm TTMP#Index#_gfx9_gfx10 : SIRegLoHi16<"ttmp"#Index, !add(108, Index)>; defm TTMP#Index#_gfx9_gfx10 : SIRegLoHi16<"ttmp"#Index, !add(108, Index)>;
defm TTMP#Index : SIRegLoHi16<"ttmp"#Index, 0>; defm TTMP#Index : SIRegLoHi16<"ttmp"#Index, 0>;
@ -274,7 +274,7 @@ def FLAT_SCR_vi : FlatReg<FLAT_SCR_LO_vi, FLAT_SCR_HI_vi, 102>;
def FLAT_SCR : FlatReg<FLAT_SCR_LO, FLAT_SCR_HI, 0>; def FLAT_SCR : FlatReg<FLAT_SCR_LO, FLAT_SCR_HI, 0>;
// SGPR registers // SGPR registers
foreach Index = 0-105 in { foreach Index = 0...105 in {
defm SGPR#Index : defm SGPR#Index :
SIRegLoHi16 <"s"#Index, Index>, SIRegLoHi16 <"s"#Index, Index>,
DwarfRegNum<[!if(!le(Index, 63), !add(Index, 32), !add(Index, 1024)), DwarfRegNum<[!if(!le(Index, 63), !add(Index, 32), !add(Index, 1024)),
@ -282,14 +282,14 @@ foreach Index = 0-105 in {
} }
// VGPR registers // VGPR registers
foreach Index = 0-255 in { foreach Index = 0...255 in {
defm VGPR#Index : defm VGPR#Index :
SIRegLoHi16 <"v"#Index, Index, 0, 1>, SIRegLoHi16 <"v"#Index, Index, 0, 1>,
DwarfRegNum<[!add(Index, 2560), !add(Index, 1536)]>; DwarfRegNum<[!add(Index, 2560), !add(Index, 1536)]>;
} }
// AccVGPR registers // AccVGPR registers
foreach Index = 0-255 in { foreach Index = 0...255 in {
defm AGPR#Index : defm AGPR#Index :
SIRegLoHi16 <"a"#Index, Index, 1, 1>, SIRegLoHi16 <"a"#Index, Index, 1, 1>,
DwarfRegNum<[!add(Index, 3072), !add(Index, 2048)]>; DwarfRegNum<[!add(Index, 3072), !add(Index, 2048)]>;
@ -389,7 +389,7 @@ def TTMP_512Regs : SIRegisterTuples<getSubRegs<16>.ret, TTMP_32, 15, 4, 16, "ttm
class TmpRegTuplesBase<int index, int size, class TmpRegTuplesBase<int index, int size,
list<Register> subRegs, list<Register> subRegs,
list<SubRegIndex> indices = getSubRegs<size>.ret, list<SubRegIndex> indices = getSubRegs<size>.ret,
int index1 = !add(index, !add(size, -1)), int index1 = !add(index, size, -1),
string name = "ttmp["#index#":"#index1#"]"> : string name = "ttmp["#index#":"#index1#"]"> :
RegisterWithSubRegs<name, subRegs> { RegisterWithSubRegs<name, subRegs> {
let HWEncoding = subRegs[0].HWEncoding; let HWEncoding = subRegs[0].HWEncoding;

View File

@ -4,24 +4,32 @@
// CHECK: --- Defs --- // CHECK: --- Defs ---
// CHECK: def A0 { // CHECK: def A0 {
// CHECK: bits<8> add = { 0, 1, 0, 0, 0, 0, 0, 0 }; // CHECK: bits<8> add = { 0, 0, 0, 1, 1, 0, 0, 0 };
// CHECK: bits<8> sub = { 0, 0, 0, 1, 0, 0, 1, 0 };
// CHECK: bits<8> and = { 0, 0, 0, 0, 0, 0, 0, 1 }; // CHECK: bits<8> and = { 0, 0, 0, 0, 0, 0, 0, 1 };
// CHECK: bits<8> or = { 0, 0, 1, 1, 1, 1, 1, 1 }; // CHECK: bits<8> or = { 0, 0, 0, 1, 0, 1, 1, 1 };
// CHECK: bits<8> xor = { 0, 0, 1, 1, 1, 1, 1, 0 }; // CHECK: bits<8> xor = { 0, 0, 0, 1, 0, 1, 1, 0 };
// CHECK: bits<8> srl = { 0, 0, 0, 1, 1, 1, 1, 1 }; // CHECK: bits<8> srl = { 0, 0, 0, 0, 0, 0, 1, 0 };
// CHECK: bits<8> sra = { 0, 0, 0, 1, 1, 1, 1, 1 }; // CHECK: bits<8> sra = { 0, 0, 0, 0, 0, 0, 1, 0 };
// CHECK: bits<8> shl = { 0, 1, 1, 1, 1, 1, 1, 0 }; // CHECK: bits<8> shl = { 1, 0, 1, 0, 1, 0, 0, 0 };
// CHECK: }
// CHECK: bits<8> sra = { 1, 1, 1, 1, 1, 1, 1, 1 };
class A<bits<8> a, bits<2> b> { class A<bits<8> a, bits<2> b> {
// Operands of different bits types are allowed. // Operands of different bits types are allowed.
bits<8> add = !add(a, b); bits<8> add = !add(a, b);
bits<8> sub = !sub(a, b);
bits<8> and = !and(a, b); bits<8> and = !and(a, b);
bits<8> or = !or(a, b); bits<8> or = !or(a, b);
bits<8> xor = !xor(a, b); bits<8> xor = !xor(a, b);
bits<8> srl = !srl(a, b); bits<8> srl = !srl(a, b);
bits<8> sra = !sra(a, b); bits<8> sra = !sra(a, b);
bits<8> shl = !shl(a, b); bits<8> shl = !shl(a, b);
} }
def A0 : A<63, 1>; def A0 : A<21, 3>;
def A1 {
bits<8> sra = !sra(-1, 3);
}

View File

@ -1,37 +1,45 @@
// RUN: llvm-tblgen %s | FileCheck %s // RUN: llvm-tblgen %s | FileCheck %s
// XFAIL: vg_leak // XFAIL: vg_leak
// CHECK: def shifts
// CHECK: shifted_b = 8
// CHECK: shifted_i = 8
def shifts { def shifts {
bits<2> b = 0b10; bits<2> b = 0b10;
int i = 2; int i = 2;
int shifted_b = !shl(b, 2); int shifted_b = !shl(b, 2);
int shifted_i = !shl(i, 2); int shifted_i = !shl(i, 2);
} }
// CHECK: def shifts
// CHECK: shifted_b = 8
// CHECK: shifted_i = 8
class Int<int value> { class Int<int value> {
int Value = value; int Value = value;
} }
def v1022 : Int<1022>; def int2 : Int<2>;
def int1022 : Int<1022>;
def int1024 : Int<1024>;
// CHECK: def v0 // CHECK: def v0a
// CHECK: Value = 0 // CHECK: Value = 0
def v0a : Int<!sub(int1024.Value, int1024.Value)>;
// CHECK: def v0b
// CHECK: Value = 0
def v0b : Int<!and(int1024.Value, 2048)>;
// CHECK: def v1 // CHECK: def v1
// CHECK: Value = 1 // CHECK: Value = 1
def v1 : Int<!and(1025, 1)>;
// CHECK: def v1019
// CHECK: Value = 1019
def v1019 : Int<!sub(int1022.Value, 3)>;
// CHECK: def v1023 // CHECK: def v1023
// CHECK: Value = 1023 // CHECK: Value = 1023
def v1023 : Int<!or(v1022.Value, 1)>; def v1023 : Int<!or(int1022.Value, 1)>;
def v1024 : Int<1024>; def v1025 : Int<!add(int1024.Value, 1)>;
// CHECK: def v1024
// CHECK: Value = 1024
def v1025 : Int<!add(v1024.Value, 1)>;
// CHECK: def v1025 // CHECK: def v1025
// CHECK: Value = 1025 // CHECK: Value = 1025
@ -42,20 +50,13 @@ def v12 : Int<!mul(4, 3)>;
// CHECK: def v1a // CHECK: def v1a
// CHECK: Value = 1 // CHECK: Value = 1
// CHECK: def v2
// CHECK: Value = 2
def v2 : Int<2>;
def v2048 : Int<!add(v1024.Value, v1024.Value)>;
// CHECK: def v2048 // CHECK: def v2048
// CHECK: Value = 2048 // CHECK: Value = 2048
def v2048 : Int<!add(int1024.Value, int1024.Value)>;
def v0 : Int<!and(v1024.Value, v2048.Value)>;
def v1 : Int<!and(v1025.Value, 1)>;
// CHECK: def v3072 // CHECK: def v3072
// CHECK: Value = 3072 // CHECK: Value = 3072
def v3072 : Int<!or(v1024.Value, v2048.Value)>; def v3072 : Int<!or(int1024.Value, v2048.Value)>;
// CHECK: def v4 // CHECK: def v4
// CHECK: Value = 4 // CHECK: Value = 4
@ -63,8 +64,8 @@ def v3072 : Int<!or(v1024.Value, v2048.Value)>;
// CHECK: def v7 // CHECK: def v7
// CHECK: Value = 7 // CHECK: Value = 7
def v4 : Int<!add(v2.Value, 1, v1.Value)>; def v4 : Int<!add(int2.Value, 1, v1.Value)>;
def v7 : Int<!or(v1.Value, v2.Value, v4.Value)>; def v7 : Int<!or(v1.Value, int2.Value, v4.Value)>;
def v1a : Int<!and(v7.Value, 5, v1.Value)>; def v1a : Int<!and(v7.Value, 5, v1.Value)>;
// CHECK: def v84 // CHECK: def v84
@ -79,4 +80,10 @@ def v9 : Int<!xor(v7.Value, 0x0E)>;
// CHECK: Value = 924 // CHECK: Value = 924
def v924 : Int<!mul(v84.Value, 11)>; def v924 : Int<!mul(v84.Value, 11)>;
// CHECK: def v925
// CHECK: Value = 925
def v925 : Int<!sub(v924.Value, -1)>;
// CHECK: def vneg
// CHECK: Value = -2
def vneg : Int<!sub(v925.Value, 927)>;