TableGen: Reimplement !foreach using the resolving mechanism

Summary:
This changes the syntax of !foreach so that the first "parameter" is
a new syntactic variable: !foreach(x, lst, expr) will define the
variable x within the scope of expr, and evaluation of the !foreach
will substitute elements of the given list (or dag) for x in expr.

Aside from leading to a nicer syntax, this allows more complex
expressions where x is deeply nested, or even constant expressions
in which x does not occur at all.

!foreach is currently not actually used anywhere in trunk, but I
plan to use it in the AMDGPU backend. If out-of-tree targets are
using it, they can adjust to the new syntax very easily.

Change-Id: Ib966694d8ab6542279d6bc358b6f4d767945a805

Reviewers: arsenm, craig.topper, tra, MartinO

Subscribers: wdng, llvm-commits, tpr

Differential Revision: https://reviews.llvm.org/D43651

llvm-svn: 326705
This commit is contained in:
Nicolai Haehnle 2018-03-05 15:21:04 +00:00
parent 0b0eaf7ee2
commit 8ebf7e4dfa
9 changed files with 274 additions and 152 deletions

View File

@ -192,9 +192,9 @@ supported include:
for 'a' in 'c.' This operation is analogous to $(subst) in GNU make.
``!foreach(a, b, c)``
For each member of dag or list 'b' apply operator 'c.' 'a' is a dummy
variable that should be declared as a member variable of an instantiated
class. This operation is analogous to $(foreach) in GNU make.
For each member of dag or list 'b' apply operator 'c'. 'a' is the name
of a variable that will be substituted by members of 'b' in 'c'.
This operation is analogous to $(foreach) in GNU make.
``!head(a)``
The first element of list 'a.'

View File

@ -17,6 +17,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/SmallVector.h"
@ -1030,6 +1031,9 @@ public:
static FieldInit *get(Init *R, StringInit *FN);
Init *getRecord() const { return Rec; }
StringInit *getFieldName() const { return FieldName; }
Init *getBit(unsigned Bit) const override;
Init *resolveReferences(Resolver &R) const override;
@ -1632,6 +1636,26 @@ public:
virtual bool keepUnsetBits() const { return false; }
};
/// Resolve arbitrary mappings.
class MapResolver final : public Resolver {
struct MappedValue {
Init *V;
bool Resolved;
MappedValue() : V(nullptr), Resolved(false) {}
MappedValue(Init *V, bool Resolved) : V(V), Resolved(Resolved) {}
};
DenseMap<Init *, MappedValue> Map;
public:
explicit MapResolver(Record *CurRec = nullptr) : Resolver(CurRec) {}
void set(Init *Key, Init *Value) { Map[Key] = {Value, false}; }
Init *resolve(Init *VarName) override;
};
/// Resolve all variables from a record except for unset variables.
class RecordResolver final : public Resolver {
DenseMap<Init *, Init *> Cache;
@ -1664,6 +1688,23 @@ public:
}
};
/// Delegate resolving to a sub-resolver, but shadow some variable names.
class ShadowResolver final : public Resolver {
Resolver &R;
DenseSet<Init *> Shadowed;
public:
explicit ShadowResolver(Resolver &R) : Resolver(R.getCurrentRecord()), R(R) {}
void addShadow(Init *Key) { Shadowed.insert(Key); }
Init *resolve(Init *VarName) override {
if (Shadowed.count(VarName))
return nullptr;
return R.resolve(VarName);
}
};
} // end namespace llvm
#endif // LLVM_TABLEGEN_RECORD_H

View File

@ -912,83 +912,57 @@ void TernOpInit::Profile(FoldingSetNodeID &ID) const {
ProfileTernOpInit(ID, getOpcode(), getLHS(), getMHS(), getRHS(), getType());
}
// Evaluates operation RHSo after replacing all operands matching LHS with Arg.
static Init *EvaluateOperation(OpInit *RHSo, Init *LHS, Init *Arg,
Record *CurRec, MultiClass *CurMultiClass) {
static Init *ForeachApply(Init *LHS, Init *MHSe, Init *RHS, Record *CurRec) {
MapResolver R(CurRec);
R.set(LHS, MHSe);
return RHS->resolveReferences(R);
}
SmallVector<Init *, 8> NewOperands;
NewOperands.reserve(RHSo->getNumOperands());
for (unsigned i = 0, e = RHSo->getNumOperands(); i < e; ++i) {
if (auto *RHSoo = dyn_cast<OpInit>(RHSo->getOperand(i))) {
if (Init *Result =
EvaluateOperation(RHSoo, LHS, Arg, CurRec, CurMultiClass))
NewOperands.push_back(Result);
else
NewOperands.push_back(RHSoo);
} else if (LHS->getAsString() == RHSo->getOperand(i)->getAsString()) {
NewOperands.push_back(Arg);
} else {
NewOperands.push_back(RHSo->getOperand(i));
}
static Init *ForeachDagApply(Init *LHS, DagInit *MHSd, Init *RHS,
Record *CurRec) {
bool Change = false;
Init *Val = ForeachApply(LHS, MHSd->getOperator(), RHS, CurRec);
if (Val != MHSd->getOperator())
Change = true;
SmallVector<std::pair<Init *, StringInit *>, 8> NewArgs;
for (unsigned int i = 0; i < MHSd->getNumArgs(); ++i) {
Init *Arg = MHSd->getArg(i);
Init *NewArg;
StringInit *ArgName = MHSd->getArgName(i);
if (DagInit *Argd = dyn_cast<DagInit>(Arg))
NewArg = ForeachDagApply(LHS, Argd, RHS, CurRec);
else
NewArg = ForeachApply(LHS, Arg, RHS, CurRec);
NewArgs.push_back(std::make_pair(NewArg, ArgName));
if (Arg != NewArg)
Change = true;
}
// Now run the operator and use its result as the new leaf
const OpInit *NewOp = RHSo->clone(NewOperands);
Init *NewVal = NewOp->Fold(CurRec, CurMultiClass);
return (NewVal != NewOp) ? NewVal : nullptr;
if (Change)
return DagInit::get(Val, nullptr, NewArgs);
return MHSd;
}
// Applies RHS to all elements of MHS, using LHS as a temp variable.
static Init *ForeachHelper(Init *LHS, Init *MHS, Init *RHS, RecTy *Type,
Record *CurRec, MultiClass *CurMultiClass) {
OpInit *RHSo = dyn_cast<OpInit>(RHS);
Record *CurRec) {
if (DagInit *MHSd = dyn_cast<DagInit>(MHS))
return ForeachDagApply(LHS, MHSd, RHS, CurRec);
if (!RHSo)
PrintFatalError(CurRec->getLoc(), "!foreach requires an operator\n");
TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
if (!LHSt)
PrintFatalError(CurRec->getLoc(), "!foreach requires typed variable\n");
DagInit *MHSd = dyn_cast<DagInit>(MHS);
if (MHSd) {
Init *Val = MHSd->getOperator();
if (Init *Result = EvaluateOperation(RHSo, LHS, Val, CurRec, CurMultiClass))
Val = Result;
SmallVector<std::pair<Init *, StringInit *>, 8> args;
for (unsigned int i = 0; i < MHSd->getNumArgs(); ++i) {
Init *Arg = MHSd->getArg(i);
StringInit *ArgName = MHSd->getArgName(i);
// If this is a dag, recurse
if (isa<DagInit>(Arg)) {
if (Init *Result =
ForeachHelper(LHS, Arg, RHSo, Type, CurRec, CurMultiClass))
Arg = Result;
} else if (Init *Result =
EvaluateOperation(RHSo, LHS, Arg, CurRec, CurMultiClass)) {
Arg = Result;
}
// TODO: Process arg names
args.push_back(std::make_pair(Arg, ArgName));
}
return DagInit::get(Val, nullptr, args);
}
ListInit *MHSl = dyn_cast<ListInit>(MHS);
ListRecTy *ListType = dyn_cast<ListRecTy>(Type);
if (MHSl && ListType) {
if (ListInit *MHSl = dyn_cast<ListInit>(MHS)) {
SmallVector<Init *, 8> NewList(MHSl->begin(), MHSl->end());
for (Init *&Arg : NewList) {
if (Init *Result =
EvaluateOperation(RHSo, LHS, Arg, CurRec, CurMultiClass))
Arg = Result;
for (Init *&Item : NewList) {
Init *NewItem = ForeachApply(LHS, Item, RHS, CurRec);
if (NewItem != Item)
Item = NewItem;
}
return ListInit::get(NewList, ListType->getElementType());
return ListInit::get(NewList, cast<ListRecTy>(Type)->getElementType());
}
return nullptr;
}
@ -1038,8 +1012,7 @@ Init *TernOpInit::Fold(Record *CurRec, MultiClass *CurMultiClass) const {
}
case FOREACH: {
if (Init *Result =
ForeachHelper(LHS, MHS, RHS, getType(), CurRec, CurMultiClass))
if (Init *Result = ForeachHelper(LHS, MHS, RHS, getType(), CurRec))
return Result;
break;
}
@ -1081,7 +1054,15 @@ Init *TernOpInit::resolveReferences(Resolver &R) const {
}
Init *mhs = MHS->resolveReferences(R);
Init *rhs = RHS->resolveReferences(R);
Init *rhs;
if (getOpcode() == FOREACH) {
ShadowResolver SR(R);
SR.addShadow(lhs);
rhs = RHS->resolveReferences(SR);
} else {
rhs = RHS->resolveReferences(R);
}
if (LHS != lhs || MHS != mhs || RHS != rhs)
return (TernOpInit::get(getOpcode(), lhs, mhs, rhs, getType()))
@ -1821,6 +1802,24 @@ Init *llvm::QualifyName(Record &CurRec, MultiClass *CurMultiClass,
return NewName;
}
Init *MapResolver::resolve(Init *VarName) {
auto It = Map.find(VarName);
if (It == Map.end())
return nullptr;
Init *I = It->second.V;
if (!It->second.Resolved && Map.size() > 1) {
// Resolve mutual references among the mapped variables, but prevent
// infinite recursion.
Map.erase(It);
I = I->resolveReferences(*this);
Map[VarName] = {I, true};
}
return I;
}
Init *RecordResolver::resolve(Init *VarName) {
Init *Val = Cache.lookup(VarName);
if (Val)

View File

@ -985,8 +985,109 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
return nullptr;
}
case tgtok::XForEach: { // Value ::= !foreach '(' Id ',' Value ',' Value ')'
SMLoc OpLoc = Lex.getLoc();
Lex.Lex(); // eat the operation
if (Lex.getCode() != tgtok::l_paren) {
TokError("expected '(' after !foreach");
return nullptr;
}
if (Lex.Lex() != tgtok::Id) { // eat the '('
TokError("first argument of !foreach must be an identifier");
return nullptr;
}
Init *LHS = StringInit::get(Lex.getCurStrVal());
if (CurRec->getValue(LHS)) {
TokError((Twine("iteration variable '") + LHS->getAsString() +
"' already defined")
.str());
return nullptr;
}
if (Lex.Lex() != tgtok::comma) { // eat the id
TokError("expected ',' in ternary operator");
return nullptr;
}
Lex.Lex(); // eat the ','
Init *MHS = ParseValue(CurRec);
if (!MHS)
return nullptr;
if (Lex.getCode() != tgtok::comma) {
TokError("expected ',' in ternary operator");
return nullptr;
}
Lex.Lex(); // eat the ','
TypedInit *MHSt = dyn_cast<TypedInit>(MHS);
if (!MHSt) {
TokError("could not get type of !foreach input");
return nullptr;
}
RecTy *InEltType = nullptr;
RecTy *OutEltType = nullptr;
bool IsDAG = false;
if (ListRecTy *InListTy = dyn_cast<ListRecTy>(MHSt->getType())) {
InEltType = InListTy->getElementType();
if (ItemType) {
if (ListRecTy *OutListTy = dyn_cast<ListRecTy>(ItemType)) {
OutEltType = OutListTy->getElementType();
} else {
Error(OpLoc,
"expected value of type '" + Twine(ItemType->getAsString()) +
"', but got !foreach of list type");
return nullptr;
}
}
} else if (DagRecTy *InDagTy = dyn_cast<DagRecTy>(MHSt->getType())) {
InEltType = InDagTy;
if (ItemType && !isa<DagRecTy>(ItemType)) {
Error(OpLoc,
"expected value of type '" + Twine(ItemType->getAsString()) +
"', but got !foreach of dag type");
return nullptr;
}
IsDAG = true;
} else {
TokError("!foreach must have list or dag input");
return nullptr;
}
CurRec->addValue(RecordVal(LHS, InEltType, false));
Init *RHS = ParseValue(CurRec, OutEltType);
CurRec->removeValue(LHS);
if (!RHS)
return nullptr;
if (Lex.getCode() != tgtok::r_paren) {
TokError("expected ')' in binary operator");
return nullptr;
}
Lex.Lex(); // eat the ')'
RecTy *OutType;
if (IsDAG) {
OutType = InEltType;
} else {
TypedInit *RHSt = dyn_cast<TypedInit>(RHS);
if (!RHSt) {
TokError("could not get type of !foreach result");
return nullptr;
}
OutType = RHSt->getType()->getListTy();
}
return (TernOpInit::get(TernOpInit::FOREACH, LHS, MHS, RHS, OutType))
->Fold(CurRec, CurMultiClass);
}
case tgtok::XIf:
case tgtok::XForEach:
case tgtok::XSubst: { // Value ::= !ternop '(' Value ',' Value ',' Value ')'
TernOpInit::TernaryOp Code;
RecTy *Type = nullptr;
@ -998,9 +1099,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XIf:
Code = TernOpInit::IF;
break;
case tgtok::XForEach:
Code = TernOpInit::FOREACH;
break;
case tgtok::XSubst:
Code = TernOpInit::SUBST;
break;
@ -1081,23 +1179,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
}
break;
}
case tgtok::XForEach: {
TypedInit *MHSt = dyn_cast<TypedInit>(MHS);
if (!MHSt) {
TokError("could not get type for !foreach");
return nullptr;
}
Type = MHSt->getType();
if (isa<ListRecTy>(Type)) {
TypedInit *RHSt = dyn_cast<TypedInit>(RHS);
if (!RHSt) {
TokError("could not get type of !foreach list elements");
return nullptr;
}
Type = RHSt->getType()->getListTy();
}
break;
}
case tgtok::XSubst: {
TypedInit *RHSt = dyn_cast<TypedInit>(RHS);
if (!RHSt) {

View File

@ -73,14 +73,6 @@ def VR128 : RegisterClass<[v2i64, v2f64],
def REGCLASS : RegisterClass<[], []>;
def MNEMONIC;
class decls {
// Dummy for foreach
dag pattern;
int operand;
}
def Decls : decls;
// Define intrinsics
def int_x86_sse2_add_ps : Intrinsic<"addps">;
def int_x86_sse2_add_pd : Intrinsic<"addpd">;
@ -92,16 +84,16 @@ class MakePat<list<dag> patterns> : Pat<patterns[0], patterns[1]>;
class Base<bits<8> opcode, dag opnds, dag iopnds, string asmstr, Intrinsic intr,
list<list<dag>> patterns>
: Inst<opcode, opnds, iopnds, asmstr,
!foreach(Decls.pattern, patterns[0],
!foreach(Decls.operand, Decls.pattern,
!foreach(pattern, patterns[0],
!foreach(operand, pattern,
!subst(INTRINSIC, intr,
!subst(REGCLASS, VR128,
!subst(MNEMONIC, set, Decls.operand)))))>,
MakePat<!foreach(Decls.pattern, patterns[1],
!foreach(Decls.operand, Decls.pattern,
!subst(MNEMONIC, set, operand)))))>,
MakePat<!foreach(pattern, patterns[1],
!foreach(operand, pattern,
!subst(INTRINSIC, intr,
!subst(REGCLASS, VR128,
!subst(MNEMONIC, set, Decls.operand)))))>;
!subst(MNEMONIC, set, operand)))))>;
multiclass arith<bits<8> opcode, string asmstr, string intr, list<list<dag>> patterns> {
def PS : Base<opcode, (outs VR128:$dst), (ins VR128:$src1, VR128:$src2),

View File

@ -69,14 +69,6 @@ def VR128 : RegisterClass<[v2i64, v2f64],
// Dummy for subst
def REGCLASS : RegisterClass<[], []>;
class decls {
// Dummy for foreach
dag pattern;
int operand;
}
def Decls : decls;
// Define intrinsics
def int_x86_sse2_add_ps : Intrinsic<"addps">;
def int_x86_sse2_add_pd : Intrinsic<"addpd">;
@ -85,17 +77,17 @@ def INTRINSIC : Intrinsic<"Dummy">;
multiclass arith<bits<8> opcode, string asmstr, string intr, list<dag> patterns> {
def PS : Inst<opcode, (outs VR128:$dst), (ins VR128:$src1, VR128:$src2),
!strconcat(asmstr, "\t$dst, $src1, $src2"),
!foreach(Decls.pattern, patterns,
!foreach(Decls.operand, Decls.pattern,
!foreach(pattern, patterns,
!foreach(operand, pattern,
!subst(INTRINSIC, !cast<Intrinsic>(!subst("SUFFIX", "_ps", intr)),
!subst(REGCLASS, VR128, Decls.operand))))>;
!subst(REGCLASS, VR128, operand))))>;
def PD : Inst<opcode, (outs VR128:$dst), (ins VR128:$src1, VR128:$src2),
!strconcat(asmstr, "\t$dst, $src1, $src2"),
!foreach(Decls.pattern, patterns,
!foreach(Decls.operand, Decls.pattern,
!foreach(pattern, patterns,
!foreach(operand, pattern,
!subst(INTRINSIC, !cast<Intrinsic>(!subst("SUFFIX", "_pd", intr)),
!subst(REGCLASS, VR128, Decls.operand))))>;
!subst(REGCLASS, VR128, operand))))>;
}
defm ADD : arith<0x58, "add", "int_x86_sse2_addSUFFIX",

View File

@ -10,11 +10,9 @@ def d3;
def d4;
class D<dag d> {
int tmp;
dag r1 = !foreach(tmp, d, !subst(d1, d0, !subst(d2, d0,
!subst(d3, d0,
!subst(d4, d0, tmp)))));
dag tmp2;
list<dag> dl = [d];
list<dag> r2 = !foreach(tmp2, dl,
!foreach(tmp, tmp2, !subst(d1, d0,
@ -29,10 +27,8 @@ class D<dag d> {
def d : D <(d0 d1, d2, d3, d4)>;
class I<list<int> i> {
int tmp;
list<int> r1 = !foreach(tmp, i, !add(3, !add(4, tmp)));
list<int> tmp2;
list<list<int>> li = [i];
list<list<int>> r2 = !foreach(tmp2, li,
!foreach(tmp, tmp2, !add(3, !add(4, tmp))));
@ -43,26 +39,20 @@ class I<list<int> i> {
// CHECK: list<list<int>> r2 = [{{[[]}}8, 9, 10]];
def i : I<[1,2,3]>;
class Tmp {
dag t0;
int t1;
}
def tmp: Tmp;
class J0<list<dag> pattern> {
list<dag> Pattern = pattern;
}
class J1<dag pattern>
: J0<[!foreach(tmp.t1, pattern, !subst(d1, d0,
!subst(d2, d0,
!subst(d3, d0,
!subst(d4, d0, tmp.t1)))))]>;
: J0<[!foreach(tmp, pattern, !subst(d1, d0,
!subst(d2, d0,
!subst(d3, d0,
!subst(d4, d0, tmp)))))]>;
class J2<list<dag> patterns>
: J0<!foreach(tmp.t0, patterns,
!foreach(tmp.t1, tmp.t0, !subst(d1, d0,
!subst(d2, d0,
!subst(d3, d0,
!subst(d4, d0, tmp.t1))))))>;
: J0<!foreach(t0, patterns,
!foreach(t1, t0, !subst(d1, d0,
!subst(d2, d0,
!subst(d3, d0,
!subst(d4, d0, t1))))))>;
// CHECK-LABEL: def j1
// CHECK: list<dag> Pattern = [(d0 d0:$dst, (d0 d0:$src1))];
def j1 : J1< (d1 d2:$dst, (d3 d4:$src1))>;

View File

@ -0,0 +1,25 @@
// RUN: llvm-tblgen %s | FileCheck %s
// XFAIL: vg_leak
// CHECK: --- Defs ---
// CHECK: def C0 {
// CHECK: list<list<int>> ret = {{\[}}[1, 2, 3], [1, 2, 3]];
// CHECK: }
// The variable name 'a' is used both in the "inner" and in the "outer" foreach.
// The test ensure that the inner declaration of 'a' properly shadows the outer
// one.
class A<list<int> lst> {
list<int> ret = !foreach(a, lst, !add(a, 1));
}
class B<list<int> lst1, list<int> lst2> {
list<list<int>> ret = !foreach(a, lst1, A<lst2>.ret);
}
class C<list<int> lst2> {
list<list<int>> ret = B<[0, 1], lst2>.ret;
}
def C0 : C<[0, 1, 2]>;

View File

@ -12,26 +12,22 @@
// CHECK: list<string> x = ["0", "1", "2"];
// CHECK: }
// CHECK: def DY {
// CHECK: list<int> y = [5, 7];
// CHECK: }
// CHECK: Jr
// CHECK: Sr
// Variables for foreach
class decls {
string name;
int num;
}
def Decls : decls;
class A<list<string> names> {
list<string> Names = names;
}
class B<list<string> names> : A<!foreach(Decls.name, names, !strconcat(Decls.name, ", Sr."))>;
class B<list<string> names> : A<!foreach(name, names, !strconcat(name, ", Sr."))>;
class C<list<string> names> : A<!foreach(Decls.name, names, !strconcat(Decls.name, ", Jr."))>;
class C<list<string> names> : A<!foreach(name, names, !strconcat(name, ", Jr."))>;
class D<list<string> names> : A<!foreach(Decls.name, names, !subst("NAME", "John Smith", Decls.name))>;
class D<list<string> names> : A<!foreach(name, names, !subst("NAME", "John Smith", name))>;
class Names {
list<string> values = ["Ken Griffey", "Seymour Cray"];
@ -45,7 +41,13 @@ def Smiths : D<["NAME", "Jane Smith"]>;
def Unprocessed : D<People.values>;
class X<list<int> a> {
list<string> x = !foreach(Decls.num, a, !cast<string>(Decls.num));
list<string> x = !foreach(num, a, !cast<string>(num));
}
def DX : X<[0, 1, 2]>;
class Y<list<int> a> {
list<int> y = !foreach(num, a, !add(!add(4, num), !add(1, num)));
}
def DY: Y<[0, 1]>;