[TableGen] Support benefit score in pattern definition.

A integer number can be specified in the pattern definition and used as the
adjustment to the default benefit score in the generated rewrite pattern C++
definition.

PiperOrigin-RevId: 240994192
This commit is contained in:
Feng Liu 2019-03-29 09:36:09 -07:00 committed by jpienaar
parent 094ca64ab0
commit 5303587448
5 changed files with 55 additions and 8 deletions

View File

@ -716,6 +716,9 @@ class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
// Pattern definitions
//===----------------------------------------------------------------------===//
// Marker used to identify the delta value added to the default benefit value.
def addBenefit;
// Base class for op+ -> op+ rewrite rules. These allow declaratively
// specifying rewrite rules.
//
@ -729,18 +732,25 @@ class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
// the result pattern, `arg*` can be used to refer to a previously bound name,
// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
// be nested DAG node.
class Pattern<dag source, list<dag> results, list<dag> preds = []> {
class Pattern<dag source, list<dag> results, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
list<dag> resultPatterns = results;
// Multi-entity constraints. Each constraint here involves multiple entities
// matched in source pattern and places further constraints on them as a
// whole.
list<dag> constraints = preds;
// The delta value added to the default benefit value. The default value is
// the number of ops in the source pattern. The rule with the highest final
// benefit value will be applied first if there are multiple rules matches.
// This delta value can be either positive or negative.
dag benefitDelta = benefitAdded;
}
// Form of a pattern which produces a single result.
class Pat<dag pattern, dag result, list<dag> preds = []> :
Pattern<pattern, [result], preds>;
class Pat<dag pattern, dag result, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> :
Pattern<pattern, [result], preds, benefitAdded>;
// Attribute transformation. This is the base class to specify a transformation
// of matched attributes. Used on the output attribute of a rewrite rule.

View File

@ -220,6 +220,9 @@ public:
// Returns the constraints.
std::vector<AppliedConstraint> getConstraints() const;
// Returns the benefit score of the pattern.
int getBenefit() const;
private:
// The TableGen definition of this pattern.
const llvm::Record &def;

View File

@ -257,3 +257,13 @@ std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
}
return ret;
}
int tblgen::Pattern::getBenefit() const {
// The default benefit value is a heristic with number of ops in the source
// pattern.
const int defaultBenefit = getSourcePattern().getNumOps();
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0)))
PrintFatalError(def.getLoc(), "The 'AddedBenefit' can only be an integer");
return defaultBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
}

View File

@ -0,0 +1,28 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def IfEqual : Constraint<CPred<"<notused>">>;
// Define ops to rewrite.
def U: Type<CPred<"true">, "U">;
def X_AddOp : Op<"x.add"> {
let arguments = (ins U, U);
}
def Y_AddOp : Op<"y.add"> {
let arguments = (ins U, U, U);
}
def Z_AddOp : Op<"z.add"> {
let arguments = (ins U);
}
// Define rewrite patterns.
def : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
// CHECK-LABEL: struct GeneratedConvert0
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("x.add", 2, context) {}
def : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
// CHECK-LABEL: struct GeneratedConvert1
// CHECK: GeneratedConvert1(MLIRContext *context) : RewritePattern("x.add", 101, context) {}

View File

@ -340,17 +340,13 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Get the DAG tree for the source pattern
DagNode tree = pattern.getSourcePattern();
// TODO(jpienaar): the benefit metric is simply number of ops matched at the
// moment, revise.
unsigned benefit = tree.getNumOps();
const Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName();
// Emit RewritePattern for Pattern.
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
rewriteName, rootName, benefit)
rewriteName, rootName, pattern.getBenefit())
<< "\n";
// Emit matched state.