forked from OSchip/llvm-project
[TableGen] Model variadic operands using Variadic<Type>
Previously, we were using the trait mechanism to specify that an op has variadic operands. That led a discrepancy between how we handle ops with deterministic number of operands. Besides, we have no way to specify the constraints and match against the variadic operands. This CL introduced Variadic<Type> as a way to solve the above issues. PiperOrigin-RevId: 232656104
This commit is contained in:
parent
0c65cf283c
commit
1df6ca5053
|
@ -100,10 +100,20 @@ class TypeConstraint<Pred condition, string descr> {
|
|||
string description = descr;
|
||||
}
|
||||
|
||||
// A type, carries type constraints, but accepts any type by default.
|
||||
// A type, carries type constraints.
|
||||
class Type<Pred condition, string descr = "">
|
||||
: TypeConstraint<condition, descr>;
|
||||
|
||||
// A variadic type. It expands to zero or more of the base type.
|
||||
// This class is used for supporting variadic operands/results. An op can
|
||||
// declare no more than one variadic operand/result, and that operand/result
|
||||
// must be the last one in the operand/result list.
|
||||
class Variadic<Type type, string descr = "">
|
||||
// TODO: support variadic type conditions
|
||||
: Type<CPred<"true">, descr> {
|
||||
Type baseType = type;
|
||||
}
|
||||
|
||||
// A type that can be constructed using MLIR::Builder.
|
||||
// Note that this does not "inherit" from Type because it would require
|
||||
// duplicating Type subclasses for buildable and non-buildable cases to avoid
|
||||
|
@ -352,11 +362,6 @@ class OpTrait<string prop> {
|
|||
string trait = prop;
|
||||
}
|
||||
|
||||
// Note: Ideally, we should be able to automatically deduce most of these traits
|
||||
// from other bits of op definitions, especially those regarding the number of
|
||||
// operands and results.
|
||||
class AtLeastNOperands<int c> : OpTrait<"AtLeastNOperands<" # c # ">::Impl">;
|
||||
|
||||
// op supports operand broadcast behavior
|
||||
def Broadcastable : OpTrait<"BroadcastableTwoOperandsOneResult">;
|
||||
// X op Y == Y op X
|
||||
|
@ -367,8 +372,6 @@ def NoSideEffect : OpTrait<"HasNoSideEffect">;
|
|||
def SameValueType : OpTrait<"SameOperandsAndResultType">;
|
||||
// op is a terminator
|
||||
def Terminator : OpTrait<"IsTerminator">;
|
||||
// op has an unknown number of operands
|
||||
def VariadicOperands : OpTrait<"VariadicOperands">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops
|
||||
|
|
|
@ -89,8 +89,8 @@ class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
// Base class for LLVM terminator operations. All terminator operations have
|
||||
// zero results and an optional list of successors.
|
||||
class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
LLVM_Op<mnemonic, !listconcat(traits, [Terminator, VariadicOperands])>,
|
||||
Results<(outs)> {
|
||||
LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>,
|
||||
Arguments<(ins Variadic<LLVM_Type>)>, Results<(outs)> {
|
||||
let builder = [{
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
ArrayRef<Value *> properOperands,
|
||||
|
@ -117,7 +117,7 @@ class LLVM_ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
// Class for variadic instructions.
|
||||
class LLVM_VariadicOneResultOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
LLVM_OneResultOp<mnemonic, !listconcat(traits, [VariadicOperands])>;
|
||||
LLVM_OneResultOp<mnemonic, traits>, Arguments<(ins Variadic<LLVM_Type>)>;
|
||||
|
||||
// Integer binary instructions.
|
||||
def LLVM_AddOp : LLVM_ArithmeticOp<"add", [Commutative]>;
|
||||
|
@ -151,7 +151,8 @@ def LLVM_BitcastOp : LLVM_OneResultOp<"bitcast", [NoSideEffect]>,
|
|||
|
||||
// Call-related instructions.
|
||||
def LLVM_CallOp : LLVM_VariadicOneResultOp<"call">;
|
||||
def LLVM_Call0Op : LLVM_ZeroResultOp<"call0", [VariadicOperands]>;
|
||||
def LLVM_Call0Op : LLVM_ZeroResultOp<"call0", []>,
|
||||
Arguments<(ins Variadic<LLVM_Type>)>;
|
||||
def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>,
|
||||
Arguments<(ins LLVM_Type)>;
|
||||
def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,
|
||||
|
|
|
@ -97,6 +97,9 @@ public:
|
|||
Operand &getOperand(int index) { return operands[index]; }
|
||||
const Operand &getOperand(int index) const { return operands[index]; }
|
||||
|
||||
// Returns true if this operation has a variadic operand.
|
||||
bool hasVariadicOperand() const;
|
||||
|
||||
// Op argument (attribute or operand) accessors.
|
||||
Argument getArg(int index);
|
||||
StringRef getArgName(int index) const;
|
||||
|
|
|
@ -71,6 +71,13 @@ public:
|
|||
explicit Type(const llvm::Record *record) : Type(*record) {}
|
||||
explicit Type(const llvm::DefInit *init);
|
||||
|
||||
// Returns true if this is a variadic type.
|
||||
bool isVariadic() const;
|
||||
|
||||
// Gets the base type of this variadic type.
|
||||
// Precondition: This type is a variadic type.
|
||||
Type getVariadicBaseType() const;
|
||||
|
||||
// Returns the TableGen def name for this type.
|
||||
StringRef getTableGenDefName() const;
|
||||
};
|
||||
|
|
|
@ -84,6 +84,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
|
|||
return attributes[index];
|
||||
}
|
||||
|
||||
bool tblgen::Operator::hasVariadicOperand() const {
|
||||
return !operands.empty() && operands.back().type.isVariadic();
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getArgName(int index) const {
|
||||
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||
return argumentValues->getArgName(index)->getValue();
|
||||
|
@ -179,6 +183,12 @@ void tblgen::Operator::populateOperandsAndAttributes() {
|
|||
Attribute(cast<DefInit>(val.getValue()))});
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = operands.size() - 1; i < e; ++i) {
|
||||
if (operands[i].type.isVariadic())
|
||||
PrintFatalError(def.getLoc(),
|
||||
"only the last operand allowed to be variadic");
|
||||
}
|
||||
}
|
||||
|
||||
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
|
||||
|
|
|
@ -61,3 +61,10 @@ tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) {
|
|||
tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {}
|
||||
|
||||
StringRef tblgen::Type::getTableGenDefName() const { return def->getName(); }
|
||||
|
||||
bool tblgen::Type::isVariadic() const { return def->isSubClassOf("Variadic"); }
|
||||
|
||||
tblgen::Type tblgen::Type::getVariadicBaseType() const {
|
||||
assert(isVariadic() && "must be variadic type");
|
||||
return Type(def->getValueAsDef("baseType"));
|
||||
}
|
||||
|
|
|
@ -228,7 +228,7 @@ void OpEmitter::emitNamedOperands() {
|
|||
)";
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
const auto &operand = op.getOperand(i);
|
||||
if (!operand.name.empty())
|
||||
if (!operand.type.isVariadic() && !operand.name.empty())
|
||||
os << formatv(operandMethods, operand.name, i);
|
||||
}
|
||||
}
|
||||
|
@ -260,8 +260,11 @@ void OpEmitter::emitBuilder() {
|
|||
os << ", Type returnType" << i;
|
||||
|
||||
// Emit parameters for all operands
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i)
|
||||
os << ", Value* " << getArgumentName(op, i);
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
auto &operand = op.getOperand(i);
|
||||
os << (operand.type.isVariadic() ? ", ArrayRef<Value*> " : ", Value* ")
|
||||
<< getArgumentName(op, i);
|
||||
}
|
||||
|
||||
// Emit parameters for all attributes
|
||||
// TODO(antiagainst): Support default initializer for attributes
|
||||
|
@ -283,12 +286,20 @@ void OpEmitter::emitBuilder() {
|
|||
}
|
||||
|
||||
// Push all operands to the result
|
||||
if (op.getNumOperands() > 0) {
|
||||
auto numOperands = op.getNumOperands();
|
||||
bool hasVariadicOperand = op.hasVariadicOperand();
|
||||
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
|
||||
if (numNonVariadicOperands > 0) {
|
||||
OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
|
||||
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
|
||||
for (int i = 1, e = numNonVariadicOperands; i < e; ++i) {
|
||||
os << ", " << getArgumentName(op, i);
|
||||
}
|
||||
os << "});\n";
|
||||
}
|
||||
if (hasVariadicOperand) {
|
||||
OUT(4) << "result->addOperands(" << getArgumentName(op, numOperands - 1)
|
||||
<< ");\n";
|
||||
}
|
||||
|
||||
// Push all attributes to the result
|
||||
for (const auto &namedAttr : op.getAttributes())
|
||||
|
@ -310,7 +321,7 @@ void OpEmitter::emitBuilder() {
|
|||
<< " result->addTypes(resultTypes);\n";
|
||||
|
||||
// Operands
|
||||
OUT(4) << "assert(args.size() == " << op.getNumOperands()
|
||||
OUT(4) << "assert(args.size() == " << numNonVariadicOperands
|
||||
<< "u && \"mismatched number of parameters\");\n"
|
||||
<< " result->addOperands(args);\n\n";
|
||||
|
||||
|
@ -422,9 +433,12 @@ void OpEmitter::emitVerifier() {
|
|||
OUT(4) << "}\n";
|
||||
}
|
||||
|
||||
// TODO: Handle variadic.
|
||||
int opIndex = 0;
|
||||
for (const auto &operand : op.getOperands()) {
|
||||
// TODO: Handle variadic operand verification.
|
||||
if (operand.type.isVariadic())
|
||||
continue;
|
||||
|
||||
// TODO: Commonality between matchers could be extracted to have a more
|
||||
// concise code.
|
||||
if (operand.hasMatcher()) {
|
||||
|
@ -466,38 +480,33 @@ void OpEmitter::emitTraits() {
|
|||
break;
|
||||
}
|
||||
|
||||
// Track explicitly added operand size traits. Note that some ops might
|
||||
// implicitly defines the number of operands via the Argument dag.
|
||||
bool hasVariadicOperands = false;
|
||||
bool hasAtLeastNOperands = false;
|
||||
|
||||
// Add variadic size trait and normal op traits.
|
||||
for (StringRef trait : def.getValueAsListOfStrings("traits")) {
|
||||
if (trait == "VariadicOperands") {
|
||||
hasVariadicOperands = true;
|
||||
} else if (trait.startswith("AtLeastNOperands")) {
|
||||
hasAtLeastNOperands = true;
|
||||
}
|
||||
os << ", OpTrait::" << trait;
|
||||
}
|
||||
|
||||
if ((hasVariadicOperands || hasAtLeastNOperands) && op.getNumOperands() > 0) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
"Operands number definition is not consistent.");
|
||||
}
|
||||
auto numOperands = op.getNumOperands();
|
||||
bool hasVariadicOperand = op.hasVariadicOperand();
|
||||
|
||||
// Add operand size trait.
|
||||
switch (op.getNumOperands()) {
|
||||
case 0:
|
||||
if (!hasVariadicOperands && !hasAtLeastNOperands)
|
||||
os << ", OpTrait::ZeroOperands";
|
||||
break;
|
||||
case 1:
|
||||
os << ", OpTrait::OneOperand";
|
||||
break;
|
||||
default:
|
||||
os << ", OpTrait::NOperands<" << op.getNumOperands() << ">::Impl";
|
||||
break;
|
||||
os << ", OpTrait::";
|
||||
if (hasVariadicOperand) {
|
||||
if (numOperands == 1)
|
||||
os << "VariadicOperands";
|
||||
else
|
||||
os << "AtLeastNOperands<" << (numOperands - 1) << ">::Impl";
|
||||
} else {
|
||||
switch (op.getNumOperands()) {
|
||||
case 0:
|
||||
os << "ZeroOperands";
|
||||
break;
|
||||
case 1:
|
||||
os << "OneOperand";
|
||||
break;
|
||||
default:
|
||||
os << "NOperands<" << numOperands << ">::Impl";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue