forked from OSchip/llvm-project
[TableGen] Capture bound source ops in PatternState
This allows accessing those bound source ops in result patterns, which can be useful for invoking native C++ op creation. We bind the op entirely here because ops can have multiple results. Design a approach to bind to a specific result is not the concern of this commit. -- PiperOrigin-RevId: 244724750
This commit is contained in:
parent
f7f2760c30
commit
09b623aa93
|
@ -205,23 +205,15 @@ public:
|
|||
// Returns the DAG tree root node of the `index`-th result pattern.
|
||||
DagNode getResultPattern(unsigned index) const;
|
||||
|
||||
// Returns true if an argument with the given `name` is bound in source
|
||||
// pattern.
|
||||
bool isArgBoundInSourcePattern(llvm::StringRef name) const;
|
||||
|
||||
// Returns true if an argument with the given `name` is bound as result of
|
||||
// op in pattern.
|
||||
bool isResultBoundInSourcePattern(llvm::StringRef name) const;
|
||||
|
||||
// Checks whether an argument with the given `name` is bound in source
|
||||
// pattern. Prints fatal error if not; does nothing otherwise.
|
||||
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
|
||||
// Checks whether an argument or op with the given `name` is bound in
|
||||
// source pattern. Prints fatal error if not; does nothing otherwise.
|
||||
void ensureBoundInSourcePattern(llvm::StringRef name) const;
|
||||
|
||||
// Returns a reference to all the bound arguments in the source pattern.
|
||||
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
|
||||
|
||||
// Returns a reference to all the bound results in the source pattern.
|
||||
llvm::StringSet<> &getSourcePatternBoundResults();
|
||||
// Returns a reference to all the bound ops in the source pattern.
|
||||
llvm::StringSet<> &getSourcePatternBoundOps();
|
||||
|
||||
// Returns the op that the root node of the source pattern matches.
|
||||
const Operator &getSourceRootOp();
|
||||
|
@ -247,11 +239,11 @@ private:
|
|||
// All operators.
|
||||
RecordOperatorMap *recordOpMap;
|
||||
|
||||
// All bound arguments.
|
||||
// All bound op arguments.
|
||||
llvm::StringMap<Argument> boundArguments;
|
||||
|
||||
// All bound results.
|
||||
llvm::StringSet<> boundResults;
|
||||
// All bound ops.
|
||||
llvm::StringSet<> boundOps;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
|
|
|
@ -186,17 +186,9 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
|
|||
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||
}
|
||||
|
||||
bool tblgen::Pattern::isArgBoundInSourcePattern(llvm::StringRef name) const {
|
||||
return boundArguments.find(name) != boundArguments.end();
|
||||
}
|
||||
|
||||
bool tblgen::Pattern::isResultBoundInSourcePattern(llvm::StringRef name) const {
|
||||
return boundResults.count(name);
|
||||
}
|
||||
|
||||
void tblgen::Pattern::ensureArgBoundInSourcePattern(
|
||||
llvm::StringRef name) const {
|
||||
if (!isArgBoundInSourcePattern(name))
|
||||
void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
|
||||
if (boundArguments.find(name) == boundArguments.end() &&
|
||||
boundOps.find(name) == boundOps.end())
|
||||
PrintFatalError(def.getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
}
|
||||
|
@ -206,8 +198,8 @@ tblgen::Pattern::getSourcePatternBoundArgs() {
|
|||
return boundArguments;
|
||||
}
|
||||
|
||||
llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundResults() {
|
||||
return boundResults;
|
||||
llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundOps() {
|
||||
return boundOps;
|
||||
}
|
||||
|
||||
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
|
||||
|
@ -268,7 +260,7 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) {
|
|||
// results generated from this op. It should be remembered as bound results.
|
||||
auto treeName = tree.getOpName();
|
||||
if (!treeName.empty())
|
||||
boundResults.insert(treeName);
|
||||
boundOps.insert(treeName);
|
||||
|
||||
// TODO(jpienaar): Expand to multiple matches.
|
||||
for (unsigned i = 0; i != numTreeArgs; ++i) {
|
||||
|
|
|
@ -18,7 +18,7 @@ def OpC : Op<"op_c", []> {
|
|||
}
|
||||
|
||||
def OpD : Op<"op_d", []> {
|
||||
let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr);
|
||||
let arguments = (ins I32:$input1, I32:$input2, I32:$input3, I32Attr:$attr);
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,7 @@ def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
|
|||
|
||||
def : Pattern<(OpA:$res_a $operand, $attr),
|
||||
[(OpC:$res_c (OpB:$res_b $operand)),
|
||||
(OpD $res_b, $res_c, $attr)],
|
||||
(OpD $res_b, $res_c, $res_a, $attr)],
|
||||
[(hasOneUse $res_a)]>;
|
||||
|
||||
// CHECK-LABEL: GeneratedConvert0
|
||||
|
@ -34,20 +34,20 @@ def : Pattern<(OpA:$res_a $operand, $attr),
|
|||
// Test struct for bound arguments
|
||||
// ---
|
||||
// CHECK: struct MatchedState : public PatternState
|
||||
// CHECK: Value* operand;
|
||||
// CHECK: Value *operand;
|
||||
// CHECK: IntegerAttr attr;
|
||||
// CHECK: Operation *res_a;
|
||||
|
||||
// Test bound arguments/results in source pattern
|
||||
// ---
|
||||
// CHECK: PatternMatchResult match
|
||||
// CHECK: auto state = llvm::make_unique<MatchedState>();
|
||||
// CHECK: auto &s = *state;
|
||||
// CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a;
|
||||
// CHECK: tblgen_res_a = op0;
|
||||
// CHECK: s.res_a = op0;
|
||||
// CHECK: s.operand = op0->getOperand(0);
|
||||
// CHECK: attr = op0->getAttrOfType<IntegerAttr>("attr");
|
||||
// CHECK: s.attr = attr;
|
||||
// CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure();
|
||||
// CHECK: if (!(s.res_a->hasOneUse())) return matchFailure();
|
||||
|
||||
// Test bound results in result pattern
|
||||
// ---
|
||||
|
@ -59,4 +59,5 @@ def : Pattern<(OpA:$res_a $operand, $attr),
|
|||
// CHECK: auto vOpD0 = rewriter.create<OpD>(
|
||||
// CHECK: /*input1=*/res_b,
|
||||
// CHECK: /*input2=*/res_c,
|
||||
// CHECK: /*input3=*/s.res_a,
|
||||
// CHECK: /*attr=*/s.attr
|
||||
|
|
|
@ -29,7 +29,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
|
|||
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
|
||||
|
||||
// CHECK: struct MatchedState : public PatternState {
|
||||
// CHECK: Value* input;
|
||||
// CHECK: Value *input;
|
||||
// CHECK: IntegerAttr attr;
|
||||
// CHECK: };
|
||||
|
||||
|
|
|
@ -41,19 +41,12 @@ using namespace llvm;
|
|||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
static const char *const tblgenNamePrefix = "tblgen_";
|
||||
|
||||
// Returns the bound value for the given op result `symbol`.
|
||||
static Twine getBoundResult(const StringRef &symbol) {
|
||||
return tblgenNamePrefix + symbol;
|
||||
}
|
||||
|
||||
// Returns the bound value for the given op argument `symbol`.
|
||||
// Returns the bound symbol for the given op argument or op named `symbol`.
|
||||
//
|
||||
// Arguments bound in the source pattern are grouped into a transient
|
||||
// `PatternState` struct. This struct can be accessed in both `match()` and
|
||||
// `rewrite()` via the local variable named as `s`.
|
||||
static Twine getBoundArgument(const StringRef &symbol) {
|
||||
// Arguments and ops bound in the source pattern are grouped into a
|
||||
// transient `PatternState` struct. This struct can be accessed in both
|
||||
// `match()` and `rewrite()` via the local variable named as `s`.
|
||||
static Twine getBoundSymbol(const StringRef &symbol) {
|
||||
return Twine("s.") + symbol;
|
||||
}
|
||||
|
||||
|
@ -85,7 +78,7 @@ namespace {
|
|||
class PatternSymbolResolver {
|
||||
public:
|
||||
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
|
||||
const StringSet<> &srcResults);
|
||||
const StringSet<> &srcOperations);
|
||||
|
||||
// Marks the given `symbol` as bound. Returns false if the `symbol` is
|
||||
// already bound.
|
||||
|
@ -105,8 +98,8 @@ private:
|
|||
} // end anonymous namespace
|
||||
|
||||
PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
|
||||
const StringSet<> &srcResults)
|
||||
: sourceArguments(srcArgs), sourceOps(srcResults) {}
|
||||
const StringSet<> &srcOperations)
|
||||
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
|
||||
|
||||
bool PatternSymbolResolver::add(StringRef symbol) {
|
||||
return resultOps.insert(symbol).second;
|
||||
|
@ -121,12 +114,12 @@ std::string PatternSymbolResolver::query(StringRef symbol) const {
|
|||
{
|
||||
auto it = sourceArguments.find(symbol);
|
||||
if (it != sourceArguments.end())
|
||||
return getBoundArgument(symbol).str();
|
||||
return getBoundSymbol(symbol).str();
|
||||
}
|
||||
{
|
||||
auto it = sourceOps.find(symbol);
|
||||
if (it != sourceOps.end())
|
||||
return getBoundResult(symbol).str();
|
||||
return getBoundSymbol(symbol).str();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
@ -235,7 +228,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
|||
raw_ostream &os)
|
||||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
|
||||
symbolResolver(pattern.getSourcePatternBoundArgs(),
|
||||
pattern.getSourcePatternBoundResults()),
|
||||
pattern.getSourcePatternBoundOps()),
|
||||
nextValueId(0), os(os) {
|
||||
matchCtx.withBuilder("mlir::Builder(ctx)");
|
||||
rewriteCtx.withBuilder("rewriter");
|
||||
|
@ -274,7 +267,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
// If the operand's name is set, set to that variable.
|
||||
auto name = tree.getOpName();
|
||||
if (!name.empty())
|
||||
os.indent(indent) << formatv("{0} = op{1};\n", getBoundResult(name), depth);
|
||||
os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto opArg = op.getArg(i);
|
||||
|
@ -330,7 +323,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
|||
// Capture the value
|
||||
auto name = tree.getArgName(index);
|
||||
if (!name.empty()) {
|
||||
os.indent(indent) << getBoundArgument(name) << " = op" << depth
|
||||
os.indent(indent) << getBoundSymbol(name) << " = op" << depth
|
||||
<< "->getOperand(" << index << ");\n";
|
||||
}
|
||||
}
|
||||
|
@ -380,7 +373,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
|||
// Capture the value
|
||||
auto name = tree.getArgName(index);
|
||||
if (!name.empty()) {
|
||||
os.indent(indent) << getBoundArgument(name) << " = attr;\n";
|
||||
os.indent(indent) << getBoundSymbol(name) << " = attr;\n";
|
||||
}
|
||||
|
||||
indent -= 2;
|
||||
|
@ -405,10 +398,6 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
}
|
||||
}
|
||||
|
||||
for (auto &res : pattern.getSourcePatternBoundResults())
|
||||
os.indent(4) << formatv("mlir::Operation* {0}; (void){0};\n",
|
||||
getBoundResult(res.first()));
|
||||
|
||||
emitOpMatch(tree, 0);
|
||||
|
||||
for (auto &appliedConstraint : pattern.getConstraints()) {
|
||||
|
@ -472,9 +461,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
|
||||
<< ";\n";
|
||||
} else {
|
||||
os.indent(4) << "Value* " << fieldName << ";\n";
|
||||
os.indent(4) << "Value *" << fieldName << ";\n";
|
||||
}
|
||||
}
|
||||
for (const auto &result : pattern.getSourcePatternBoundOps()) {
|
||||
os.indent(4) << "Operation *" << result.getKey() << ";\n";
|
||||
}
|
||||
os << " };\n";
|
||||
|
||||
emitMatchMethod(tree);
|
||||
|
@ -564,9 +556,9 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
|
|||
}
|
||||
|
||||
auto name = tree.getArgName(0);
|
||||
pattern.ensureArgBoundInSourcePattern(name);
|
||||
pattern.ensureBoundInSourcePattern(name);
|
||||
|
||||
return getBoundArgument(name).str();
|
||||
return getBoundSymbol(name).str();
|
||||
}
|
||||
|
||||
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
|
||||
|
@ -587,8 +579,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
|||
auto enumCase = leaf.getAsEnumAttrCase();
|
||||
return handleConstantAttr(enumCase, enumCase.getSymbol());
|
||||
}
|
||||
pattern.ensureArgBoundInSourcePattern(argName);
|
||||
std::string result = getBoundArgument(argName).str();
|
||||
pattern.ensureBoundInSourcePattern(argName);
|
||||
std::string result = getBoundSymbol(argName).str();
|
||||
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
|
||||
return result;
|
||||
}
|
||||
|
@ -758,7 +750,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
|
|||
bool printingAttr = false;
|
||||
for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) {
|
||||
auto name = resultTree.getArgName(i);
|
||||
pattern.ensureArgBoundInSourcePattern(name);
|
||||
pattern.ensureBoundInSourcePattern(name);
|
||||
const auto &val = boundedValues.find(name);
|
||||
if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) {
|
||||
os << "}, {";
|
||||
|
@ -767,7 +759,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
|
|||
}
|
||||
if (!first)
|
||||
os << ",";
|
||||
os << getBoundArgument(name);
|
||||
os << getBoundSymbol(name);
|
||||
first = false;
|
||||
}
|
||||
if (!printingAttr)
|
||||
|
|
Loading…
Reference in New Issue