forked from OSchip/llvm-project
[TableGen] Support benefit score in pattern definition.
A integer number can be specified in the pattern definition and used as the adjustment to the default benefit score in the generated rewrite pattern C++ definition. PiperOrigin-RevId: 240994192
This commit is contained in:
parent
094ca64ab0
commit
5303587448
|
@ -716,6 +716,9 @@ class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
|
|||
// Pattern definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Marker used to identify the delta value added to the default benefit value.
|
||||
def addBenefit;
|
||||
|
||||
// Base class for op+ -> op+ rewrite rules. These allow declaratively
|
||||
// specifying rewrite rules.
|
||||
//
|
||||
|
@ -729,18 +732,25 @@ class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
|
|||
// the result pattern, `arg*` can be used to refer to a previously bound name,
|
||||
// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
|
||||
// be nested DAG node.
|
||||
class Pattern<dag source, list<dag> results, list<dag> preds = []> {
|
||||
class Pattern<dag source, list<dag> results, list<dag> preds = [],
|
||||
dag benefitAdded = (addBenefit 0)> {
|
||||
dag sourcePattern = source;
|
||||
list<dag> resultPatterns = results;
|
||||
// Multi-entity constraints. Each constraint here involves multiple entities
|
||||
// matched in source pattern and places further constraints on them as a
|
||||
// whole.
|
||||
list<dag> constraints = preds;
|
||||
// The delta value added to the default benefit value. The default value is
|
||||
// the number of ops in the source pattern. The rule with the highest final
|
||||
// benefit value will be applied first if there are multiple rules matches.
|
||||
// This delta value can be either positive or negative.
|
||||
dag benefitDelta = benefitAdded;
|
||||
}
|
||||
|
||||
// Form of a pattern which produces a single result.
|
||||
class Pat<dag pattern, dag result, list<dag> preds = []> :
|
||||
Pattern<pattern, [result], preds>;
|
||||
class Pat<dag pattern, dag result, list<dag> preds = [],
|
||||
dag benefitAdded = (addBenefit 0)> :
|
||||
Pattern<pattern, [result], preds, benefitAdded>;
|
||||
|
||||
// Attribute transformation. This is the base class to specify a transformation
|
||||
// of matched attributes. Used on the output attribute of a rewrite rule.
|
||||
|
|
|
@ -220,6 +220,9 @@ public:
|
|||
// Returns the constraints.
|
||||
std::vector<AppliedConstraint> getConstraints() const;
|
||||
|
||||
// Returns the benefit score of the pattern.
|
||||
int getBenefit() const;
|
||||
|
||||
private:
|
||||
// The TableGen definition of this pattern.
|
||||
const llvm::Record &def;
|
||||
|
|
|
@ -257,3 +257,13 @@ std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int tblgen::Pattern::getBenefit() const {
|
||||
// The default benefit value is a heristic with number of ops in the source
|
||||
// pattern.
|
||||
const int defaultBenefit = getSourcePattern().getNumOps();
|
||||
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
|
||||
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0)))
|
||||
PrintFatalError(def.getLoc(), "The 'AddedBenefit' can only be an integer");
|
||||
return defaultBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def IfEqual : Constraint<CPred<"<notused>">>;
|
||||
|
||||
// Define ops to rewrite.
|
||||
def U: Type<CPred<"true">, "U">;
|
||||
def X_AddOp : Op<"x.add"> {
|
||||
let arguments = (ins U, U);
|
||||
}
|
||||
def Y_AddOp : Op<"y.add"> {
|
||||
let arguments = (ins U, U, U);
|
||||
}
|
||||
def Z_AddOp : Op<"z.add"> {
|
||||
let arguments = (ins U);
|
||||
}
|
||||
|
||||
// Define rewrite patterns.
|
||||
def : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert0
|
||||
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("x.add", 2, context) {}
|
||||
|
||||
def : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert1
|
||||
// CHECK: GeneratedConvert1(MLIRContext *context) : RewritePattern("x.add", 101, context) {}
|
|
@ -340,17 +340,13 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
// Get the DAG tree for the source pattern
|
||||
DagNode tree = pattern.getSourcePattern();
|
||||
|
||||
// TODO(jpienaar): the benefit metric is simply number of ops matched at the
|
||||
// moment, revise.
|
||||
unsigned benefit = tree.getNumOps();
|
||||
|
||||
const Operator &rootOp = pattern.getSourceRootOp();
|
||||
auto rootName = rootOp.getOperationName();
|
||||
|
||||
// Emit RewritePattern for Pattern.
|
||||
os << formatv(R"(struct {0} : public RewritePattern {
|
||||
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
|
||||
rewriteName, rootName, benefit)
|
||||
rewriteName, rootName, pattern.getBenefit())
|
||||
<< "\n";
|
||||
|
||||
// Emit matched state.
|
||||
|
|
Loading…
Reference in New Issue