Add fallback to native code op builder specification for patterns.

This allow for arbitrarily complex builder patterns which is meant to cover initial cases while the modelling is improved and long tail cases/cases for which expanding the DSL would result in worst overall system.

NFC just sorting the emit replace methods alphabetical in the class and file body.

PiperOrigin-RevId: 231890352
This commit is contained in:
Jacques Pienaar 2019-01-31 17:57:06 -08:00 committed by jpienaar
parent 4161d44bd5
commit 82dc6a878c
4 changed files with 81 additions and 14 deletions

View File

@ -373,7 +373,8 @@ class Op<string mnemonic, list<OpProperty> props = []> {
// Additional, longer human-readable description of what the op does.
string description = "";
// Dag containting the arguments of the op. Default to 0 arguments.
// Dag containting the arguments of the op. Default to 0 arguments. Operands
// to the op need to precede attributes to ops in the argument specification.
dag arguments = (ins);
// The list of results of the op. Default to 0 results.
@ -479,6 +480,24 @@ class tAttr<code transform> {
code attrTransform = transform;
}
// Native code op creation method. This allows performing an arbitrary op
// creation/replacement by invoking a C++ function with the operands and
// attributes. The function specified needs to have the signature:
//
// void f(OperationInst *op, ArrayRef<Value *> operands,
// ArrayRef<Attribute> attrs, PatternRewriter &rewriter);
//
// The operands and attributes are passed to this function in the order of
// the DAG specified. It is the responsibility of this function to replace the
// matched op(s) using the rewriter. This is intended for the long tail op
// creation and replacement.
class cOp<string f> {
// Function to invoke with the given arguments to construct a new op. The
// operands will be passed to the function first followed by the attributes
// (as in the function signature above and required by Op arguments).
string function = f;
}
// Marker used to indicate that no new result op are generated by applying the
// rewrite pattern, so to replace the matched DAG with an existing SSA value.
def replaceWithValue;

View File

@ -104,6 +104,10 @@ public:
// Returns the specified name of the `index`-th argument.
llvm::StringRef getArgName(unsigned index) const;
// Returns the native builder for the pattern.
// Precondition: isNativeCodeBuilder.
llvm::StringRef getNativeCodeBuilder() const;
// Collects all recursively bound arguments involved in the DAG tree rooted
// from this node.
void collectBoundArguments(Pattern *pattern) const;
@ -112,6 +116,10 @@ public:
// value.
bool isReplaceWithValue() const;
// Returns true if this DAG construct is meant to invoke a native code
// constructor.
bool isNativeCodeBuilder() const;
private:
const llvm::DagInit *node; // nullptr means null DagNode
};

View File

@ -94,6 +94,17 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue";
}
bool tblgen::DagNode::isNativeCodeBuilder() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->isSubClassOf("cOp");
}
llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
assert(isNativeCodeBuilder());
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getValueAsString("function");
}
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {
getSourcePattern().collectBoundArguments(this);

View File

@ -84,10 +84,13 @@ private:
// Emits the rewrite() method.
void emitRewriteMethod();
// Emits the C++ statement to replace the matched DAG with an existing value.
void emitReplaceWithExistingValue(DagNode resultTree);
// Emits the C++ statement to replace the matched DAG with a new op.
void emitReplaceOpWithNewOp(DagNode resultTree);
// Emits the C++ statement to replace the matched DAG with an existing value.
void emitReplaceWithExistingValue(DagNode resultTree);
// Emits the C++ statement to replace the matched DAG with a native C++ built
// value.
void emitReplaceWithNativeBuilder(DagNode resultTree);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(Record *constAttr);
@ -263,7 +266,9 @@ void PatternEmitter::emitRewriteMethod() {
auto& s = *static_cast<MatchedState *>(state.get());
)";
if (resultTree.isReplaceWithValue())
if (resultTree.isNativeCodeBuilder())
emitReplaceWithNativeBuilder(resultTree);
else if (resultTree.isReplaceWithValue())
emitReplaceWithExistingValue(resultTree);
else
emitReplaceOpWithNewOp(resultTree);
@ -271,16 +276,6 @@ void PatternEmitter::emitRewriteMethod() {
os << " }\n";
}
void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
if (resultTree.getNumArgs() != 1) {
PrintFatalError(loc, "exactly one argument needed in the result pattern");
}
auto name = resultTree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
}
void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
Operator &resultOp = resultTree.getDialectOp(opMap);
auto numOpArgs =
@ -353,6 +348,40 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
os << "\n );\n";
}
void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
if (resultTree.getNumArgs() != 1) {
PrintFatalError(loc, "exactly one argument needed in the result pattern");
}
auto name = resultTree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
}
void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
os.indent(4) << resultTree.getNativeCodeBuilder() << "(op, {";
const auto &boundedValues = pattern.getSourcePatternBoundArgs();
bool first = true;
bool printingAttr = false;
for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) {
auto name = resultTree.getArgName(i);
pattern.ensureArgBoundInSourcePattern(name);
const auto &val = boundedValues.find(name);
if (val->second.isAttr() && !printingAttr) {
os << "}, {";
first = true;
printingAttr = true;
}
if (!first)
os << ",";
os << "s." << name;
first = false;
}
if (!printingAttr)
os << "},{";
os << "}, rewriter);\n";
}
void PatternEmitter::emit(StringRef rewriteName, Record *p,
RecordOperatorMap *mapper, raw_ostream &os) {
PatternEmitter(p, mapper, os).emit(rewriteName);