[TableGen] Assign created ops to variables and rewrite with PatternRewriter::replaceOp()

Previously we were using PatternRewrite::replaceOpWithNewOp() to both create the new op
inline and rewrite the matched op. That does not work well if we want to generate multiple
ops in a sequence. To support that, this CL changed to assign each newly created op to a
separate variable.

This CL also refactors how PatternEmitter performs the directive dispatch logic.

PiperOrigin-RevId: 233206819
This commit is contained in:
Lei Zhang 2019-02-09 06:36:23 -08:00 committed by jpienaar
parent d7e6b33e93
commit a57b398906
4 changed files with 116 additions and 42 deletions

View File

@ -183,6 +183,10 @@ public:
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
// Returns true if an argument with the given `name` is bound in source
// pattern.
bool isArgBoundInSourcePattern(llvm::StringRef name) const;
// Checks whether an argument with the given `name` is bound in source
// pattern. Prints fatal error if not; does nothing otherwise.
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;

View File

@ -183,9 +183,13 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
bool tblgen::Pattern::isArgBoundInSourcePattern(llvm::StringRef name) const {
return boundArguments.find(name) != boundArguments.end();
}
void tblgen::Pattern::ensureArgBoundInSourcePattern(
llvm::StringRef name) const {
if (boundArguments.find(name) == boundArguments.end())
if (!isArgBoundInSourcePattern(name))
PrintFatalError(def.getLoc(),
Twine("referencing unbound variable '") + name + "'");
}

View File

@ -24,6 +24,6 @@ def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
// CHECK: PatternMatchResult match(Instruction *
// CHECK: void rewrite(Instruction *op, std::unique_ptr<PatternState>
// CHECK: PatternRewriter &rewriter)
// CHECK: rewriter.replaceOpWithNewOp<AddOp>(op, op->getResult(0)->getType()
// CHECK: rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType()
// CHECK: void populateWithGenerated
// CHECK: patterns->push_back(std::make_unique<GeneratedConvert0>(context))

View File

@ -52,8 +52,7 @@ public:
raw_ostream &os);
private:
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {}
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
// Emits the mlir::RewritePattern struct named `rewriteName`.
void emit(StringRef rewriteName);
@ -64,14 +63,6 @@ private:
// Emits the rewrite() method.
void emitRewriteMethod();
// Emits the C++ statement to replace the matched DAG with a new op.
void emitReplaceOpWithNewOp(DagNode resultTree);
// Emits the C++ statement to replace the matched DAG with an existing value.
void emitReplaceWithExistingValue(DagNode resultTree);
// Emits the C++ statement to replace the matched DAG with a native C++ built
// value.
void emitReplaceWithNativeBuilder(DagNode resultTree);
// Emits the value of constant attribute to `os`.
void emitConstantAttr(tblgen::ConstantAttr constAttr);
@ -86,6 +77,30 @@ private:
// `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
// Returns a unique name for an value of the given `op`.
std::string getUniqueValueName(const Operator *op);
// Entry point for handling a rewrite pattern rooted at `resultTree` and
// dispatches to concrete handlers. The given tree is the `resultIndex`-th
// argument of the enclosing DAG.
std::string handleRewritePattern(DagNode resultTree, int resultIndex,
int depth, llvm::StringRef treeName = "");
// Emits the C++ statement to replace the matched DAG with a native C++ built
// value.
std::string emitReplaceWithNativeBuilder(DagNode resultTree);
// Returns the C++ expression referencing the old value serving as the
// replacement.
std::string handleReplaceWithValue(DagNode tree);
// Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If `treeName` is not
// empty, the created op will be assigned to a variable of the given
// `treeName`. Otherwise, a unique name will be used as the result value name.
std::string emitOpCreate(DagNode tree, int resultIndex, int depth,
llvm::StringRef treeName = "");
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
@ -95,10 +110,17 @@ private:
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
tblgen::Pattern pattern;
// The next unused ID for newly created values
unsigned nextValueId;
raw_ostream &os;
};
} // end namespace
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
os(os) {}
void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
auto attr = constAttr.getAttribute();
@ -283,31 +305,77 @@ void PatternEmitter::emitRewriteMethod() {
void rewrite(Instruction *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
auto& s = *static_cast<MatchedState *>(state.get());
auto loc = op->getLoc(); (void)loc;
)";
if (resultTree.isNativeCodeBuilder())
emitReplaceWithNativeBuilder(resultTree);
else if (resultTree.isReplaceWithValue())
emitReplaceWithExistingValue(resultTree);
else
emitReplaceOpWithNewOp(resultTree);
std::string resultValue =
handleRewritePattern(resultTree, /*resultIndex=*/0, /*depth=*/0);
os << " }\n";
os.indent(4) << "rewriter.replaceOp(op, {" << resultValue;
os << "});\n }\n";
}
void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
Operator &resultOp = resultTree.getDialectOp(opMap);
std::string PatternEmitter::getUniqueValueName(const Operator *op) {
return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
}
std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
int resultIndex, int depth,
llvm::StringRef treeName) {
if (resultTree.isNativeCodeBuilder())
return emitReplaceWithNativeBuilder(resultTree);
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree);
return emitOpCreate(resultTree, resultIndex, depth, treeName);
}
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
assert(tree.isReplaceWithValue());
if (tree.getNumArgs() != 1) {
PrintFatalError(
loc, "replaceWithValue directive must take exactly one argument");
}
auto name = tree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
// We are referencing some bound value in the source pattern. Those values are
// grouped into a transient struct named as `s`.
return std::string("s.") + name.str();
}
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
int depth, StringRef treeName) {
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs =
resultOp.getNumOperands() + resultOp.getNumNativeAttributes();
os << formatv(R"(
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.getCppClassName());
if (numOpArgs != resultTree.getNumArgs()) {
if (numOpArgs != tree.getNumArgs()) {
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
resultOp.getOperationName(),
resultTree.getNumArgs(), numOpArgs));
resultOp.getOperationName(), tree.getNumArgs(),
numOpArgs));
}
std::string resultValue =
treeName.empty() ? getUniqueValueName(&resultOp) : treeName.str();
// TODO: this is a hack to support various constant ops. We are assuming
// all of them have no operands and one attribute here. Figure out a better
// way to do this.
if (resultOp.getNumOperands() == 0 &&
resultOp.getNumNativeAttributes() == 1) {
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
resultOp.getQualCppClassName());
} else {
std::string resultType = formatv("op->getResult({0})", resultIndex).str();
os.indent(4) << formatv(
"auto {0} = rewriter.create<{1}>(loc, {2}->getType()", resultValue,
resultOp.getQualCppClassName(), resultType);
}
// Create the builder call for the result.
@ -317,7 +385,7 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
// Start each operand on its own line.
(os << ",\n").indent(6);
auto name = resultTree.getArgName(i);
auto name = tree.getArgName(i);
pattern.ensureArgBoundInSourcePattern(name);
if (!operand.name.empty())
os << "/*" << operand.name << "=*/";
@ -327,13 +395,13 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
}
// Add attributes.
for (int e = resultTree.getNumArgs(); i != e; ++i) {
for (int e = tree.getNumArgs(); i != e; ++i) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
auto leaf = resultTree.getArgAsLeaf(i);
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto patArgName = resultTree.getArgName(i);
auto patArgName = tree.getArgName(i);
// The argument in the op definition.
auto opArgName = resultOp.getArgName(i);
@ -360,20 +428,16 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
}
}
os << "\n );\n";
return resultValue;
}
void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
if (resultTree.getNumArgs() != 1) {
PrintFatalError(loc, "exactly one argument needed in the result pattern");
}
std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
// The variable's name for holding the result of this native builder call
std::string value = formatv("v{0}", nextValueId++).str();
auto name = resultTree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
}
void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
os.indent(4) << resultTree.getNativeCodeBuilder() << "(op, {";
os.indent(4) << "auto " << value << " = " << resultTree.getNativeCodeBuilder()
<< "(op, {";
const auto &boundedValues = pattern.getSourcePatternBoundArgs();
bool first = true;
bool printingAttr = false;
@ -394,6 +458,8 @@ void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
if (!printingAttr)
os << "},{";
os << "}, rewriter);\n";
return value;
}
void PatternEmitter::emit(StringRef rewriteName, Record *p,