Introduce OpOperandAdaptors and emit them from ODS

When manipulating generic operations, such as in dialect conversion /
rewriting, it is often necessary to view a list of Values as operands to an
operation without creating the operation itself.  The absence of such view
makes dialect conversion patterns, among others, to use magic numbers to obtain
specific operands from a list of rewritten values when converting an operation.
Introduce XOpOperandAdaptor classes that wrap an ArrayRef<Value *> and provide
accessor functions identical to those available in XOp.  This makes it possible
for conversions to use these adaptors to address the operands with names rather
than rely on their position in the list.  The adaptors are generated from ODS
together with the actual operation definitions.

This is another step towards making dialect conversion patterns specific for a
given operation.

Illustrate the approach on conversion patterns in the standard to LLVM dialect
conversion.

PiperOrigin-RevId: 251232899
This commit is contained in:
Alex Zinenko 2019-06-03 08:03:20 -07:00 committed by Mehdi Amini
parent 3ea8575058
commit 252de8eca0
9 changed files with 313 additions and 83 deletions

View File

@ -400,11 +400,13 @@ the other for definitions. The former is generated via the `-gen-op-decls`
command-line option, while the latter is via the `-gen-op-defs` option.
The definition file contains all the op method definitions, which can be
included and enabled by defining `GET_OP_CLASSES`. Besides, it also
contains a comma-separated list of all defined ops, which can be included
and enabled by defining `GET_OP_LIST`.
included and enabled by defining `GET_OP_CLASSES`. For each operation,
OpDefinitionsGen generates an operation class and an
[operand adaptor](#operand-adaptors) class. Besides, it also contains a
comma-separated list of all defined ops, which can be included and enabled by
defining `GET_OP_LIST`.
### Class name and namespaces
#### Class name and namespaces
For each operation, its generated C++ class name is the symbol `def`ed with
TableGen with dialect prefix removed. The first `_` serves as the delimiter.
@ -423,6 +425,36 @@ match exactly with the operation name as explained in
[Operation name](#operation-name). This is to allow flexible naming to satisfy
coding style requirements.
#### Operand adaptors
For each operation, we automatically generate an _operand adaptor_. This class
solves the problem of accessing operands provided as a list of `Value`s without
using "magic" constants. The operand adaptor takes a reference to an array of
`Value *` and provides methods with the same names as those in the operation
class to access them. For example, for a binary arithmethic operation, it may
provide `.lhs()` to access the first operand and `.rhs()` to access the second
operand.
The operand adaptor class lives in the same namespace as the operation class,
and has the name of the operation followed by `OperandAdaptor`. A template
declaration `OperandAdaptor<>` is provided to look up the operand adaptor for
the given operation.
Operand adaptors can be used in function templates that also process operations:
```c++
template <typename BinaryOpTy>
std::pair<Value *, Value *> zip(BinaryOpTy &&op) {
return std::make_pair(op.lhs(), op.rhs());;
}
void process(AddOp op, ArrayRef<Value *> newOperands) {
zip(op);
zip(OperandAdaptor<AddOp>(newOperands));
/*...*/
}
```
## Constraints
Constraint is a core concept in table-driven operation definition: operation

View File

@ -49,6 +49,14 @@ class RewritePattern;
class Type;
class Value;
/// This is an adaptor from a list of values to named operands of OpTy. In a
/// generic operation context, e.g., in dialect conversions, an ordered array of
/// `Value`s is treated as operands of `OpTy`. This adaptor takes a reference
/// to the array and provides accessors with the same names as `OpTy` for
/// operands. This makes possible to create function templates that operate on
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
/// This is a vector that owns the patterns inside of it.
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;

View File

@ -663,7 +663,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
%3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
}];
let arguments = (ins AnyMemRef);
let arguments = (ins AnyMemRef:$source);
let results = (outs AnyMemRef);
let extraClassDeclaration = [{

View File

@ -157,6 +157,12 @@ public:
// Returns this op's extra class declaration code.
StringRef getExtraClassDeclaration() const;
// Returns the Tablegen definition this operator was constructed from.
// TODO(antiagainst,zinenko): do not expose the TableGen record, this is a
// temporary solution to OpEmitter requiring a Record because Operator does
// not provide enough methods.
const llvm::Record &getDef() const;
private:
// Populates the vectors containing operands, attributes, results and traits.
void populateOpStructure();

View File

@ -503,6 +503,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
Function *freeFunc =
@ -513,11 +514,12 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
}
auto type = operands[0]->getType().cast<LLVM::LLVMType>();
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
Value *bufferPtr = extractMemRefElementPtr(
rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
Value *bufferPtr =
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
@ -542,13 +544,14 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
auto targetType = memRefCastOp.getType();
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
// Copy the data buffer pointer.
auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
Value *buffer =
extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
elementTypePtr, sourceType.hasStaticShape());
// Account for static memrefs as target types
if (targetType.hasStaticShape())
@ -583,7 +586,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
operands[0], // NB: dynamic memref
transformed.source(), // NB: dynamic memref
getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
@ -612,8 +615,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "expected exactly one operand");
auto dimOp = cast<DimOp>(op);
OperandAdaptor<DimOp> transformed(operands);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
auto shape = type.getShape();
@ -630,7 +633,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
++position;
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), operands[0],
op, getIndexType(), transformed.memrefOrTensor(),
getIntegerArrayAttr(rewriter, position));
} else {
rewriter.replaceOp(
@ -759,10 +762,11 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
OperandAdaptor<LoadOp> transformed(operands);
auto type = loadOp.getMemRefType();
Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
operands.drop_front(), rewriter, getModule());
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
auto elementType = lowering.convertType(type.getElementType());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
@ -778,10 +782,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
OperandAdaptor<StoreOp> transformed(operands);
Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
operands.drop_front(2), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, operands[0], dataPtr);
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
}
};

View File

@ -87,6 +87,8 @@ StringRef tblgen::Operator::getExtraClassDeclaration() const {
return def.getValueAsString(attr);
}
const llvm::Record &tblgen::Operator::getDef() const { return def; }
tblgen::TypeConstraint
tblgen::Operator::getResultTypeConstraint(int index) const {
DagInit *results = def.getValueAsDag("results");

View File

@ -39,9 +39,19 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
// CHECK-LABEL: NS::AOp declarations
// CHECK: class AOpOperandAdaptor {
// CHECK: public:
// CHECK: AOpOperandAdaptor(ArrayRef<Value *> values);
// CHECK: Value *a();
// CHECK: ArrayRef<Value *> b();
// CHECK: private:
// CHECK: ArrayRef<Value *> tblgen_operands;
// CHECK: };
// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl> {
// CHECK: public:
// CHECK: using Op::Op;
// CHECK: using OperandAdaptor = AOpOperandAdaptor;
// CHECK: static StringRef getOperationName();
// CHECK: Value *a();
// CHECK: Operation::operand_range b();

View File

@ -14,6 +14,12 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK-LABEL: OpA definitions
// CHECK: OpAOperandAdaptor::OpAOperandAdaptor
// CHECK-NEXT: tblgen_operands = values
// CHECK: OpAOperandAdaptor::input
// CHECK-NEXT: return tblgen_operands[0]
// CHECK: void OpA::build
// CHECK-SAME: Value *input
// CHECK: tblgen_state->operands.push_back(input);
@ -40,6 +46,16 @@ def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<AnyTensor>:$input1, Variadic<AnyTensor>:$input2);
}
// CHECK-LABEL: OpCOperandAdaptor::input1
// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 0) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
// CHECK-LABEL: OpCOperandAdaptor::input2
// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 0) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Operation::operand_range OpC::input1()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 0) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
@ -58,6 +74,21 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
}
// CHECK-LABEL: OpDOperandAdaptor::input1() {
// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 1) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Value *OpDOperandAdaptor::input2() {
// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 1) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return tblgen_operands[offset];
// CHECK-LABEL: OpDOperandAdaptor::input3() {
// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 1) / 2;
// CHECK-NEXT: offset = 1 + variadicOperandSize * 1;
// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Operation::operand_range OpD::input1()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
@ -66,7 +97,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-LABEL: Value *OpD::input2()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return this->getOperand(offset);
// CHECK-NEXT: return this->getOperation()->getOperand(offset);
// CHECK-LABEL: Operation::operand_range OpD::input3()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2;
@ -82,6 +113,21 @@ def OpE : NS_Op<"one_variadic_among_multi_normal_inputs_op", []> {
let arguments = (ins AnyTensor:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3, AnyTensor:$input4, AnyTensor:$input5);
}
// CHECK-LABEL: Value *OpEOperandAdaptor::input1() {
// CHECK-NEXT: return tblgen_operands[0];
// CHECK-LABEL: Value *OpEOperandAdaptor::input2() {
// CHECK-NEXT: return tblgen_operands[1];
// CHECK-LABEL: OpEOperandAdaptor::input3() {
// CHECK-NEXT: return {std::next(tblgen_operands.begin(), 2), std::next(tblgen_operands.begin(), 2 + tblgen_operands.size() - 4)};
// CHECK-LABEL: Value *OpEOperandAdaptor::input4() {
// CHECK-NEXT: return tblgen_operands[tblgen_operands.size() - 2];
// CHECK-LABEL: Value *OpEOperandAdaptor::input5() {
// CHECK-NEXT: return tblgen_operands[tblgen_operands.size() - 1];
// CHECK-LABEL: Value *OpE::input1()
// CHECK-NEXT: return this->getOperation()->getOperand(0);

View File

@ -110,8 +110,8 @@ public:
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
// Returns true if the given C++ `type` ends with '&' or '*'.
static bool endsWithRefOrPtr(StringRef type);
// Returns true if the given C++ `type` ends with '&' or '*', or is empty.
static bool elideSpaceAfterType(StringRef type);
std::string returnType;
std::string methodName;
@ -142,7 +142,8 @@ public:
// querying properties.
enum Property {
MP_None = 0x0,
MP_Static = 0x1, // Static method
MP_Static = 0x1, // Static method
MP_Constructor = 0x2, // Constructor
};
OpMethod(StringRef retType, StringRef name, StringRef params,
@ -168,19 +169,22 @@ private:
OpMethodBody methodBody;
};
// Class for holding an op for C++ code emission
class OpClass {
// A class used to emit C++ classes from Tablegen. Contains a list of public
// methods and a list of private fields to be emitted.
class Class {
public:
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
explicit Class(StringRef name);
// Adds an op trait.
void addTrait(Twine trait);
// Creates a new method in this op's class.
// Creates a new method in this class.
OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
OpMethod::Property = OpMethod::MP_None,
bool declOnly = false);
OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
// Creates a new field in this class.
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
// Writes this op's class as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the method definitions in this op's class to the given `os`.
@ -189,11 +193,27 @@ public:
// Returns the C++ class name of the op.
StringRef getClassName() const { return className; }
protected:
std::string className;
SmallVector<OpMethod, 8> methods;
SmallVector<std::string, 4> fields;
};
// Class for holding an op for C++ code emission
class OpClass : public Class {
public:
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
// Adds an op trait.
void addTrait(Twine trait);
// Writes this op's class as a declaration to the given `os`. Redefines
// Class::writeDeclTo to also emit traits and extra class declarations.
void writeDeclTo(raw_ostream &os) const;
private:
StringRef className;
StringRef extraClassDeclaration;
SmallVector<std::string, 4> traits;
SmallVector<OpMethod, 8> methods;
};
} // end anonymous namespace
@ -202,7 +222,7 @@ OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
: returnType(retType), methodName(name), parameters(params) {}
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << methodName
os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
<< "(" << parameters << ")";
}
@ -224,13 +244,13 @@ void OpMethodSignature::writeDefTo(raw_ostream &os,
return result;
};
os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << namePrefix
os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
<< (namePrefix.empty() ? "" : "::") << methodName << "("
<< removeParamDefaultValue(parameters) << ")";
}
bool OpMethodSignature::endsWithRefOrPtr(StringRef type) {
return type.endswith("&") || type.endswith("*");
bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
return type.empty() || type.endswith("&") || type.endswith("*");
}
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
@ -287,41 +307,71 @@ void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
os << "}";
}
Class::Class(StringRef name) : className(name) {}
OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
OpMethod::Property property, bool declOnly) {
methods.emplace_back(retType, name, params, property, declOnly);
return methods.back();
}
OpMethod &Class::newConstructor(StringRef params, bool declOnly) {
return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
declOnly);
}
void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
std::string varName = formatv("{0} {1}", type, name).str();
std::string field = defaultValue.empty()
? varName
: formatv("{0} = {1}", varName, defaultValue).str();
fields.push_back(std::move(field));
}
void Class::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " {\n";
os << "public:\n";
for (const auto &method : methods) {
method.writeDeclTo(os);
os << '\n';
}
os << '\n';
os << "private:\n";
for (const auto &field : fields)
os.indent(2) << field << ";\n";
os << "};\n";
}
void Class::writeDefTo(raw_ostream &os) const {
for (const auto &method : methods) {
method.writeDefTo(os, className);
os << "\n\n";
}
}
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: className(name), extraClassDeclaration(extraClassDeclaration) {}
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
// Adds the given trait to this op. Prefixes "OpTrait::" to `trait` implicitly.
void OpClass::addTrait(Twine trait) {
traits.push_back(("OpTrait::" + trait).str());
}
OpMethod &OpClass::newMethod(StringRef retType, StringRef name,
StringRef params, OpMethod::Property property,
bool declOnly) {
methods.emplace_back(retType, name, params, property, declOnly);
return methods.back();
}
void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public Op<" << className;
for (const auto &trait : traits)
os << ", " << trait;
os << "> {\npublic:\n";
os << " using Op::Op;\n";
os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";
for (const auto &method : methods) {
method.writeDeclTo(os);
os << "\n";
}
// TODO: Add line control markers to make errors easier to debug.
os << extraClassDeclaration << "\n";
os << "};";
}
void OpClass::writeDefTo(raw_ostream &os) const {
for (const auto &method : methods) {
method.writeDefTo(os, className);
os << "\n\n";
}
if (!extraClassDeclaration.empty())
os << extraClassDeclaration << "\n";
os << "};\n";
}
//===----------------------------------------------------------------------===//
@ -332,11 +382,11 @@ namespace {
// Helper class to emit a record into the given output stream.
class OpEmitter {
public:
static void emitDecl(const Record &def, raw_ostream &os);
static void emitDef(const Record &def, raw_ostream &os);
static void emitDecl(const Operator &op, raw_ostream &os);
static void emitDef(const Operator &op, raw_ostream &os);
private:
OpEmitter(const Record &def);
OpEmitter(const Operator &op);
void emitDecl(raw_ostream &os);
void emitDef(raw_ostream &os);
@ -385,6 +435,8 @@ private:
void genOpNameGetter();
// The TableGen record for this op.
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
// it should rather go through the Operator for better abstraction.
const Record &def;
// The wrapper operator class for querying information from this op.
@ -398,8 +450,8 @@ private:
};
} // end anonymous namespace
OpEmitter::OpEmitter(const Record &def)
: def(def), op(def),
OpEmitter::OpEmitter(const Operator &op)
: def(op.getDef()), op(op),
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
verifyCtx.withOp("(*this->getOperation())");
@ -418,22 +470,19 @@ OpEmitter::OpEmitter(const Record &def)
genFolderDecls();
}
void OpEmitter::emitDecl(const Record &def, raw_ostream &os) {
OpEmitter(def).emitDecl(os);
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpEmitter(op).emitDecl(os);
}
void OpEmitter::emitDef(const Record &def, raw_ostream &os) {
OpEmitter(def).emitDef(os);
void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
OpEmitter(op).emitDef(os);
}
void OpEmitter::emitDecl(raw_ostream &os) {
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
opClass.writeDeclTo(os);
}
void OpEmitter::emitDef(raw_ostream &os) {
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
opClass.writeDefTo(os);
}
@ -480,7 +529,21 @@ void OpEmitter::genAttrGetters() {
}
}
void OpEmitter::genNamedOperandGetters() {
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
// return a range of operands (individual operands are `Value *` and each
// element in the range must also be `Value *`); use `rangeBeginCall` to get an
// iterator to the beginning of the operand range; use `rangeSizeCall` to obtain
// the number of operands. `getOperandCallPattern` contains the code necessary
// to obtain a single operand whose position will be substituted instead of
// "{0}" marker in the pattern. Note that the pattern should work for any kind
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
StringRef rangeType,
StringRef rangeBeginCall,
StringRef rangeSizeCall,
StringRef getOperandCallPattern) {
const int numOperands = op.getNumOperands();
const int numVariadicOperands = op.getNumVariadicOperands();
const int numNormalOperands = numOperands - numVariadicOperands;
@ -499,20 +562,22 @@ void OpEmitter::genNamedOperandGetters() {
continue;
if (operand.isVariadic()) {
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
auto &m = opClass.newMethod(rangeType, operand.name);
m.body() << formatv(
" return {{std::next(operand_begin(), {0}), "
"std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
i, numNormalOperands);
" return {{std::next({2}, {0}), std::next({2}, {0} + {3} - {1})};",
i, numNormalOperands, rangeBeginCall, rangeSizeCall);
emittedVariadicOperand = true;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << " return this->getOperation()->getOperand(";
if (emittedVariadicOperand)
m.body() << "this->getNumOperands() - " << numOperands - i;
else
m.body() << i;
m.body() << ");\n";
auto operandIndex =
emittedVariadicOperand
? formatv("{0} - {1}", rangeSizeCall, numOperands - i).str()
: std::to_string(i);
m.body() << " return "
<< formatv(getOperandCallPattern.data(), operandIndex)
<< ";\n";
}
}
return;
@ -535,27 +600,36 @@ void OpEmitter::genNamedOperandGetters() {
continue;
const char *code = R"(
int variadicOperandSize = (this->getNumOperands() - {0}) / {1};
int variadicOperandSize = ({4} - {0}) / {1};
int offset = {2} + variadicOperandSize * {3};
return )";
auto sizeAndOffset =
formatv(code, numNormalOperands, numVariadicOperands,
emittedNormalOperands, emittedVariadicOperands);
emittedNormalOperands, emittedVariadicOperands, rangeSizeCall);
if (operand.isVariadic()) {
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
m.body() << sizeAndOffset
<< "{std::next(operand_begin(), offset), "
"std::next(operand_begin(), offset + variadicOperandSize)};";
auto &m = opClass.newMethod(rangeType, operand.name);
m.body() << sizeAndOffset << "{std::next(" << rangeBeginCall
<< ", offset), std::next(" << rangeBeginCall
<< ", offset + variadicOperandSize)};";
++emittedVariadicOperands;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << sizeAndOffset << "this->getOperand(offset);";
m.body() << sizeAndOffset
<< formatv(getOperandCallPattern.data(), "offset") << ";";
++emittedNormalOperands;
}
}
}
void OpEmitter::genNamedOperandGetters() {
generateNamedOperandGetters(
op, opClass, /*rangeType=*/"Operation::operand_range",
/*rangeBeginCall=*/"operand_begin()",
/*rangeSizeCall=*/"this->getNumOperands()",
/*getOperandCallPattern=*/"this->getOperation()->getOperand({0})");
}
void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
const int numVariadicResults = op.getNumVariadicResults();
@ -1103,15 +1177,61 @@ void OpEmitter::genOpNameGetter() {
method.body() << " return \"" << op.getOperationName() << "\";\n";
}
//===----------------------------------------------------------------------===//
// OpOperandAdaptor emitter
//===----------------------------------------------------------------------===//
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around ArrayRef<Value *> that provide named operand
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
static void emitDecl(const Operator &op, raw_ostream &os);
static void emitDef(const Operator &op, raw_ostream &os);
private:
explicit OpOperandAdaptorEmitter(const Operator &op);
Class adapterClass;
};
} // end namespace
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
: adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
adapterClass.newField("ArrayRef<Value *>", "tblgen_operands");
auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values");
constructor.body() << " tblgen_operands = values;\n";
generateNamedOperandGetters(op, adapterClass,
/*rangeType=*/"ArrayRef<Value *>",
/*rangeBeginCall=*/"tblgen_operands.begin()",
/*rangeSizeCall=*/"tblgen_operands.size()",
/*getOperandCallPattern=*/"tblgen_operands[{0}]");
}
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
}
// Emits the opcode enum and op classes.
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
bool emitDecl) {
IfDefScope scope("GET_OP_CLASSES", os);
for (auto *def : defs) {
Operator op(*def);
if (emitDecl) {
OpEmitter::emitDecl(*def, os);
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os);
} else {
OpEmitter::emitDef(*def, os);
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os);
}
}
}