[TableGen] Capture bound source ops in PatternState

This allows accessing those bound source ops in result patterns, which can be
    useful for invoking native C++ op creation.

    We bind the op entirely here because ops can have multiple results. Design a
    approach to bind to a specific result is not the concern of this commit.

--

PiperOrigin-RevId: 244724750
This commit is contained in:
Lei Zhang 2019-04-22 13:40:30 -07:00 committed by Mehdi Amini
parent f7f2760c30
commit 09b623aa93
5 changed files with 46 additions and 69 deletions

View File

@ -205,23 +205,15 @@ public:
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
// Returns true if an argument with the given `name` is bound in source
// pattern.
bool isArgBoundInSourcePattern(llvm::StringRef name) const;
// Returns true if an argument with the given `name` is bound as result of
// op in pattern.
bool isResultBoundInSourcePattern(llvm::StringRef name) const;
// Checks whether an argument with the given `name` is bound in source
// pattern. Prints fatal error if not; does nothing otherwise.
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
// Checks whether an argument or op with the given `name` is bound in
// source pattern. Prints fatal error if not; does nothing otherwise.
void ensureBoundInSourcePattern(llvm::StringRef name) const;
// Returns a reference to all the bound arguments in the source pattern.
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
// Returns a reference to all the bound results in the source pattern.
llvm::StringSet<> &getSourcePatternBoundResults();
// Returns a reference to all the bound ops in the source pattern.
llvm::StringSet<> &getSourcePatternBoundOps();
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@ -247,11 +239,11 @@ private:
// All operators.
RecordOperatorMap *recordOpMap;
// All bound arguments.
// All bound op arguments.
llvm::StringMap<Argument> boundArguments;
// All bound results.
llvm::StringSet<> boundResults;
// All bound ops.
llvm::StringSet<> boundOps;
};
} // end namespace tblgen

View File

@ -186,17 +186,9 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
bool tblgen::Pattern::isArgBoundInSourcePattern(llvm::StringRef name) const {
return boundArguments.find(name) != boundArguments.end();
}
bool tblgen::Pattern::isResultBoundInSourcePattern(llvm::StringRef name) const {
return boundResults.count(name);
}
void tblgen::Pattern::ensureArgBoundInSourcePattern(
llvm::StringRef name) const {
if (!isArgBoundInSourcePattern(name))
void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
if (boundArguments.find(name) == boundArguments.end() &&
boundOps.find(name) == boundOps.end())
PrintFatalError(def.getLoc(),
Twine("referencing unbound variable '") + name + "'");
}
@ -206,8 +198,8 @@ tblgen::Pattern::getSourcePatternBoundArgs() {
return boundArguments;
}
llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundResults() {
return boundResults;
llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundOps() {
return boundOps;
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
@ -268,7 +260,7 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) {
// results generated from this op. It should be remembered as bound results.
auto treeName = tree.getOpName();
if (!treeName.empty())
boundResults.insert(treeName);
boundOps.insert(treeName);
// TODO(jpienaar): Expand to multiple matches.
for (unsigned i = 0; i != numTreeArgs; ++i) {

View File

@ -18,7 +18,7 @@ def OpC : Op<"op_c", []> {
}
def OpD : Op<"op_d", []> {
let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr);
let arguments = (ins I32:$input1, I32:$input2, I32:$input3, I32Attr:$attr);
let results = (outs I32:$result);
}
@ -26,7 +26,7 @@ def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
def : Pattern<(OpA:$res_a $operand, $attr),
[(OpC:$res_c (OpB:$res_b $operand)),
(OpD $res_b, $res_c, $attr)],
(OpD $res_b, $res_c, $res_a, $attr)],
[(hasOneUse $res_a)]>;
// CHECK-LABEL: GeneratedConvert0
@ -34,20 +34,20 @@ def : Pattern<(OpA:$res_a $operand, $attr),
// Test struct for bound arguments
// ---
// CHECK: struct MatchedState : public PatternState
// CHECK: Value* operand;
// CHECK: Value *operand;
// CHECK: IntegerAttr attr;
// CHECK: Operation *res_a;
// Test bound arguments/results in source pattern
// ---
// CHECK: PatternMatchResult match
// CHECK: auto state = llvm::make_unique<MatchedState>();
// CHECK: auto &s = *state;
// CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a;
// CHECK: tblgen_res_a = op0;
// CHECK: s.res_a = op0;
// CHECK: s.operand = op0->getOperand(0);
// CHECK: attr = op0->getAttrOfType<IntegerAttr>("attr");
// CHECK: s.attr = attr;
// CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure();
// CHECK: if (!(s.res_a->hasOneUse())) return matchFailure();
// Test bound results in result pattern
// ---
@ -59,4 +59,5 @@ def : Pattern<(OpA:$res_a $operand, $attr),
// CHECK: auto vOpD0 = rewriter.create<OpD>(
// CHECK: /*input1=*/res_b,
// CHECK: /*input2=*/res_c,
// CHECK: /*input3=*/s.res_a,
// CHECK: /*attr=*/s.attr

View File

@ -29,7 +29,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
// CHECK: struct MatchedState : public PatternState {
// CHECK: Value* input;
// CHECK: Value *input;
// CHECK: IntegerAttr attr;
// CHECK: };

View File

@ -41,19 +41,12 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
static const char *const tblgenNamePrefix = "tblgen_";
// Returns the bound value for the given op result `symbol`.
static Twine getBoundResult(const StringRef &symbol) {
return tblgenNamePrefix + symbol;
}
// Returns the bound value for the given op argument `symbol`.
// Returns the bound symbol for the given op argument or op named `symbol`.
//
// Arguments bound in the source pattern are grouped into a transient
// `PatternState` struct. This struct can be accessed in both `match()` and
// `rewrite()` via the local variable named as `s`.
static Twine getBoundArgument(const StringRef &symbol) {
// Arguments and ops bound in the source pattern are grouped into a
// transient `PatternState` struct. This struct can be accessed in both
// `match()` and `rewrite()` via the local variable named as `s`.
static Twine getBoundSymbol(const StringRef &symbol) {
return Twine("s.") + symbol;
}
@ -85,7 +78,7 @@ namespace {
class PatternSymbolResolver {
public:
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcResults);
const StringSet<> &srcOperations);
// Marks the given `symbol` as bound. Returns false if the `symbol` is
// already bound.
@ -105,8 +98,8 @@ private:
} // end anonymous namespace
PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcResults)
: sourceArguments(srcArgs), sourceOps(srcResults) {}
const StringSet<> &srcOperations)
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
bool PatternSymbolResolver::add(StringRef symbol) {
return resultOps.insert(symbol).second;
@ -121,12 +114,12 @@ std::string PatternSymbolResolver::query(StringRef symbol) const {
{
auto it = sourceArguments.find(symbol);
if (it != sourceArguments.end())
return getBoundArgument(symbol).str();
return getBoundSymbol(symbol).str();
}
{
auto it = sourceOps.find(symbol);
if (it != sourceOps.end())
return getBoundResult(symbol).str();
return getBoundSymbol(symbol).str();
}
return {};
}
@ -235,7 +228,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
symbolResolver(pattern.getSourcePatternBoundArgs(),
pattern.getSourcePatternBoundResults()),
pattern.getSourcePatternBoundOps()),
nextValueId(0), os(os) {
matchCtx.withBuilder("mlir::Builder(ctx)");
rewriteCtx.withBuilder("rewriter");
@ -274,7 +267,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable.
auto name = tree.getOpName();
if (!name.empty())
os.indent(indent) << formatv("{0} = op{1};\n", getBoundResult(name), depth);
os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
@ -330,7 +323,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
os.indent(indent) << getBoundArgument(name) << " = op" << depth
os.indent(indent) << getBoundSymbol(name) << " = op" << depth
<< "->getOperand(" << index << ");\n";
}
}
@ -380,7 +373,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
os.indent(indent) << getBoundArgument(name) << " = attr;\n";
os.indent(indent) << getBoundSymbol(name) << " = attr;\n";
}
indent -= 2;
@ -405,10 +398,6 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
}
}
for (auto &res : pattern.getSourcePatternBoundResults())
os.indent(4) << formatv("mlir::Operation* {0}; (void){0};\n",
getBoundResult(res.first()));
emitOpMatch(tree, 0);
for (auto &appliedConstraint : pattern.getConstraints()) {
@ -472,9 +461,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
<< ";\n";
} else {
os.indent(4) << "Value* " << fieldName << ";\n";
os.indent(4) << "Value *" << fieldName << ";\n";
}
}
for (const auto &result : pattern.getSourcePatternBoundOps()) {
os.indent(4) << "Operation *" << result.getKey() << ";\n";
}
os << " };\n";
emitMatchMethod(tree);
@ -564,9 +556,9 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
}
auto name = tree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
pattern.ensureBoundInSourcePattern(name);
return getBoundArgument(name).str();
return getBoundSymbol(name).str();
}
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
@ -587,8 +579,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
auto enumCase = leaf.getAsEnumAttrCase();
return handleConstantAttr(enumCase, enumCase.getSymbol());
}
pattern.ensureArgBoundInSourcePattern(argName);
std::string result = getBoundArgument(argName).str();
pattern.ensureBoundInSourcePattern(argName);
std::string result = getBoundSymbol(argName).str();
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
return result;
}
@ -758,7 +750,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
bool printingAttr = false;
for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) {
auto name = resultTree.getArgName(i);
pattern.ensureArgBoundInSourcePattern(name);
pattern.ensureBoundInSourcePattern(name);
const auto &val = boundedValues.find(name);
if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) {
os << "}, {";
@ -767,7 +759,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
}
if (!first)
os << ",";
os << getBoundArgument(name);
os << getBoundSymbol(name);
first = false;
}
if (!printingAttr)