[mlir][pdl] Restructure how results are represented.

Up until now, results have been represented as additional results to a pdl.operation. This is fairly clunky, as it mismatches the representation of the rest of the IR constructs(e.g. pdl.operand) and also isn't a viable representation for operations returned by pdl.create_native. This representation also creates much more difficult problems when factoring in support for variadic result groups, optional results, etc. To resolve some of these problems, and simplify adding support for variable length results, this revision extracts the representation for results out of pdl.operation in the form of a new `pdl.result` operation. This operation returns the result of an operation at a given index, e.g.:

```
%root = pdl.operation ...
%result = pdl.result 0 of %root
```

Differential Revision: https://reviews.llvm.org/D95719
This commit is contained in:
River Riddle 2021-03-16 13:11:07 -07:00
parent 1bc8f5fbb4
commit 242762c9a3
10 changed files with 359 additions and 394 deletions

View File

@ -48,7 +48,7 @@ def PDL_Dialect : Dialect {
%resultType = pdl.type
%inputOperand = pdl.operand
%root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
%root = pdl.operation "foo.op"(%inputOperand) -> %resultType
pdl.rewrite %root {
pdl.replace %root with (%inputOperand)
}

View File

@ -177,7 +177,7 @@ def PDL_OperandOp : PDL_Op<"operand", [HasParent<"pdl::PatternOp">]> {
let description = [{
`pdl.operand` operations capture external operand edges into an operation
node that originate from operations or block arguments not otherwise
specified within the pattern (e.g. via `pdl.operation`). These operations
specified within the pattern (e.g. via `pdl.result`). These operations
define individual operands of a given operation. A `pdl.operand` may
partially constrain an operand by specifying an expected value type
(via a `pdl.type` operation).
@ -223,8 +223,8 @@ def PDL_OperationOp
`pdl.operation`s are composed of a name, and a set of attribute, operand,
and result type values, that map to what those that would be on a
constructed instance of that operation. The results of a `pdl.operation` are
a handle to the operation itself, and a handle to each of the operation
result values.
a handle to the operation itself. Handles to the results of the operation
can be extracted via `pdl.result`.
When used within a matching context, the name of the operation may be
omitted.
@ -241,7 +241,7 @@ def PDL_OperationOp
```mlir
// Define an instance of a `foo.op` operation.
%op, %results:4 = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type
%op = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type
```
}];
@ -250,8 +250,13 @@ def PDL_OperationOp
Variadic<PDL_Attribute>:$attributes,
StrArrayAttr:$attributeNames,
Variadic<PDL_Type>:$types);
let results = (outs PDL_Operation:$op,
Variadic<PDL_Value>:$results);
let results = (outs PDL_Operation:$op);
let assemblyFormat = [{
($name^)? (`(` $operands^ `)`)?
custom<OperationOpAttributes>($attributes, $attributeNames)
(`->` $types^)? attr-dict
}];
let builders = [
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name,
CArg<"ValueRange", "llvm::None">:$operandValues,
@ -259,10 +264,9 @@ def PDL_OperationOp
CArg<"ValueRange", "llvm::None">:$attrValues,
CArg<"ValueRange", "llvm::None">:$resultTypes), [{
auto nameAttr = name ? StringAttr() : $_builder.getStringAttr(*name);
build($_builder, $_state, $_builder.getType<OperationType>(), {}, nameAttr,
build($_builder, $_state, $_builder.getType<OperationType>(), nameAttr,
operandValues, attrValues, $_builder.getStrArrayAttr(attrNames),
resultTypes);
$_state.types.append(resultTypes.size(), $_builder.getType<ValueType>());
}]>,
];
let extraClassDeclaration = [{
@ -293,7 +297,7 @@ def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> {
pdl.pattern : benefit(1) {
%resultType = pdl.type
%inputOperand = pdl.operand
%root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType)
%root = pdl.operation "foo.op"(%inputOperand) -> (%resultType)
pdl.rewrite %root {
pdl.replace %root with (%inputOperand)
}
@ -368,6 +372,39 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
}];
}
//===----------------------------------------------------------------------===//
// pdl::ResultOp
//===----------------------------------------------------------------------===//
def PDL_ResultOp : PDL_Op<"result"> {
let summary = "Extract a result from an operation";
let description = [{
`pdl.result` operations extract result edges from an operation node within
a pattern or rewrite region. The provided index is zero-based, and
represents the concrete result to extract, i.e. this is not the result index
as defined by the ODS definition of the operation.
Example:
```mlir
// Extract a result:
%operation = pdl.operation ...
%result = pdl.result 1 of %operation
// Imagine the following IR being matched:
%result_0, %result_1 = foo.op ...
// If the example pattern snippet above were matching against `foo.op` in
// the IR snippted, `%result` would correspond to `%result_1`.
```
}];
let arguments = (ins PDL_Operation:$parent, I32Attr:$index);
let results = (outs PDL_Value:$val);
let assemblyFormat = "$index `of` $parent attr-dict";
let verifier = ?;
}
//===----------------------------------------------------------------------===//
// pdl::RewriteOp
//===----------------------------------------------------------------------===//

View File

@ -85,6 +85,9 @@ private:
void generateRewriter(pdl::ReplaceOp replaceOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ResultOp resultOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::TypeOp typeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
@ -457,9 +460,10 @@ SymbolRefAttr PatternLowering::generateRewriter(
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
.Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
pdl::OperationOp, pdl::ReplaceOp, pdl::TypeOp>([&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
});
pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::TypeOp>(
[&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
});
}
}
@ -511,17 +515,15 @@ void PatternLowering::generateRewriter(
operationOp.attributeNames());
rewriteValues[operationOp.op()] = createdOp;
// Make all of the new operation results available.
OperandRange resultTypes = operationOp.types();
for (auto it : llvm::enumerate(operationOp.results())) {
// Generate accesses for any results that have their types constrained.
for (auto it : llvm::enumerate(operationOp.types())) {
Value &type = rewriteValues[it.value()];
if (type)
continue;
Value getResultVal = builder.create<pdl_interp::GetResultOp>(
loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
rewriteValues[it.value()] = getResultVal;
// If any of the types have not been resolved, make those available as well.
Value &type = rewriteValues[resultTypes[it.index()]];
if (!type)
type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
}
}
@ -540,29 +542,41 @@ void PatternLowering::generateRewriter(
void PatternLowering::generateRewriter(
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> replOperands;
// If the replacement was another operation, get its results. `pdl` allows
// for using an operation for simplicitly, but the interpreter isn't as
// user facing.
ValueRange origOperands;
if (Value replOp = replaceOp.replOperation())
origOperands = cast<pdl::OperationOp>(replOp.getDefiningOp()).results();
else
origOperands = replaceOp.replValues();
if (Value replOp = replaceOp.replOperation()) {
pdl::OperationOp op = cast<pdl::OperationOp>(replOp.getDefiningOp());
for (unsigned i = 0, e = op.types().size(); i < e; ++i)
replOperands.push_back(builder.create<pdl_interp::GetResultOp>(
replOp.getLoc(), builder.getType<pdl::ValueType>(),
mapRewriteValue(replOp), i));
} else {
for (Value operand : replaceOp.replValues())
replOperands.push_back(mapRewriteValue(operand));
}
// If there are no replacement values, just create an erase instead.
if (origOperands.empty()) {
if (replOperands.empty()) {
builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
mapRewriteValue(replaceOp.operation()));
return;
}
SmallVector<Value, 4> replOperands;
for (Value operand : origOperands)
replOperands.push_back(mapRewriteValue(operand));
builder.create<pdl_interp::ReplaceOp>(
replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
}
void PatternLowering::generateRewriter(
pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
resultOp.getLoc(), builder.getType<pdl::ValueType>(),
mapRewriteValue(resultOp.parent()), resultOp.index());
}
void PatternLowering::generateRewriter(
pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
@ -602,8 +616,8 @@ void PatternLowering::generateOperationResultTypeRewriter(
bool hasTypeInference = op.hasTypeInference();
auto resultTypeValues = op.types();
types.reserve(resultTypeValues.size());
for (auto it : llvm::enumerate(op.results())) {
Value result = it.value(), resultType = resultTypeValues[it.index()];
for (auto it : llvm::enumerate(resultTypeValues)) {
Value resultType = it.value();
// Check for an already translated value.
if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
@ -633,16 +647,11 @@ void PatternLowering::generateOperationResultTypeRewriter(
if ((replacedOp = getReplacedOperationFrom(use)))
break;
fullReplacedOperation = replacedOp;
assert(fullReplacedOperation &&
"expected replaced op to infer a result type from");
} else {
replacedOp = fullReplacedOperation.getValue();
}
// Infer from the result, as there was no fully replaced op.
if (!replacedOp) {
for (OpOperand &use : result.getUses())
if ((replacedOp = getReplacedOperationFrom(use)))
break;
assert(replacedOp && "expected replaced op to infer a result type from");
}
auto replOpOp = cast<pdl::OperationOp>(replacedOp);
types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));

View File

@ -433,7 +433,7 @@ public:
Position *getRoot() { return OperationPosition::getRoot(uniquer); }
/// Returns the parent position defining the value held by the given operand.
Position *getParent(OperandPosition *p) {
OperationPosition *getParent(OperandPosition *p) {
std::vector<unsigned> index = p->getIndex();
index.push_back(p->getOperandNumber());
return OperationPosition::get(uniquer, index);

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
@ -20,11 +21,108 @@ using namespace mlir::pdl_to_pdl_interp;
// Predicate List Building
//===----------------------------------------------------------------------===//
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
Position *pos);
/// Compares the depths of two positions.
static bool comparePosDepth(Position *lhs, Position *rhs) {
return lhs->getIndex().size() < rhs->getIndex().size();
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type or value, add a constraint.
if (Value type = attr.type())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.valueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
OperandPosition *pos) {
assert(val.getType().isa<pdl::ValueType>() && "expected value type");
// Prevent traversal into a null value.
predList.emplace_back(pos, builder.getIsNotNull());
// If this is a typed operand, add a type constraint.
if (auto in = val.getDefiningOp<pdl::OperandOp>()) {
if (Value type = in.type())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
// Otherwise, recurse into a result node.
} else if (auto resultOp = val.getDefiningOp<pdl::ResultOp>()) {
OperationPosition *parentPos = builder.getParent(pos);
Position *resultPos = builder.getResult(parentPos, resultOp.index());
predList.emplace_back(parentPos, builder.getIsNotNull());
predList.emplace_back(resultPos, builder.getEqualTo(pos));
getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos);
}
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
OperationPosition *pos) {
assert(val.getType().isa<pdl::OperationType>() && "expected operation");
pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
OperationPosition *opPos = cast<OperationPosition>(pos);
// Ensure getDefiningOp returns a non-null operation.
if (!opPos->isRoot())
predList.emplace_back(pos, builder.getIsNotNull());
// Check that this is the correct root operation.
if (Optional<StringRef> opName = op.name())
predList.emplace_back(pos, builder.getOperationName(*opName));
// Check that the operation has the proper number of operands and results.
OperandRange operands = op.operands();
OperandRange types = op.types();
predList.emplace_back(pos, builder.getOperandCount(operands.size()));
predList.emplace_back(pos, builder.getResultCount(types.size()));
// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
getTreePredicates(
predList, std::get<1>(it), builder, inputs,
builder.getAttribute(opPos,
std::get<0>(it).cast<StringAttr>().getValue()));
}
for (auto operandIt : llvm::enumerate(operands)) {
getTreePredicates(predList, operandIt.value(), builder, inputs,
builder.getOperand(opPos, operandIt.index()));
}
for (auto &resultIt : llvm::enumerate(types)) {
auto *resultPos = builder.getResult(pos, resultIt.index());
predList.emplace_back(resultPos, builder.getIsNotNull());
getTreePredicates(predList, resultIt.value(), builder, inputs,
builder.getType(resultPos));
}
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
TypePosition *pos) {
assert(val.getType().isa<pdl::TypeType>() && "expected value type");
pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
// Check for a constraint on a constant type.
if (Optional<Type> type = typeOp.type())
predList.emplace_back(pos, builder.getTypeConstraint(*type));
}
/// Collect the tree predicates anchored at the given value.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
@ -32,139 +130,72 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Position *pos) {
// Make sure this input value is accessible to the rewrite.
auto it = inputs.try_emplace(val, pos);
// If this is an input value that has been visited in the tree, add a
// constraint to ensure that both instances refer to the same value.
if (!it.second &&
isa<pdl::AttributeOp, pdl::OperandOp, pdl::TypeOp>(val.getDefiningOp())) {
auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth);
predList.emplace_back(minMaxPositions.second,
builder.getEqualTo(minMaxPositions.first));
if (!it.second) {
// If this is an input value that has been visited in the tree, add a
// constraint to ensure that both instances refer to the same value.
if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperationOp, pdl::TypeOp>(
val.getDefiningOp())) {
auto minMaxPositions =
std::minmax(pos, it.first->second, comparePosDepth);
predList.emplace_back(minMaxPositions.second,
builder.getEqualTo(minMaxPositions.first));
}
return;
}
// Check for a per-position predicate to apply.
switch (pos->getKind()) {
case Predicates::AttributePos: {
assert(val.getType().isa<pdl::AttributeType>() &&
"expected attribute type");
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type, add a type constraint.
if (Value type = attr.type()) {
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
// Check for a constant value of the attribute.
} else if (Optional<Attribute> value = attr.value()) {
predList.emplace_back(pos, builder.getAttributeConstraint(*value));
}
break;
}
case Predicates::OperandPos: {
assert(val.getType().isa<pdl::ValueType>() && "expected value type");
// Prevent traversal into a null value.
predList.emplace_back(pos, builder.getIsNotNull());
// If this is a typed operand, add a type constraint.
if (auto in = val.getDefiningOp<pdl::OperandOp>()) {
if (Value type = in.type()) {
getTreePredicates(predList, type, builder, inputs,
builder.getType(pos));
}
// Otherwise, recurse into the parent node.
} else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) {
getTreePredicates(predList, parentOp.op(), builder, inputs,
builder.getParent(cast<OperandPosition>(pos)));
}
break;
}
case Predicates::OperationPos: {
assert(val.getType().isa<pdl::OperationType>() && "expected operation");
pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
OperationPosition *opPos = cast<OperationPosition>(pos);
// Ensure getDefiningOp returns a non-null operation.
if (!opPos->isRoot())
predList.emplace_back(pos, builder.getIsNotNull());
// Check that this is the correct root operation.
if (Optional<StringRef> opName = op.name())
predList.emplace_back(pos, builder.getOperationName(*opName));
// Check that the operation has the proper number of operands and results.
OperandRange operands = op.operands();
ResultRange results = op.results();
predList.emplace_back(pos, builder.getOperandCount(operands.size()));
predList.emplace_back(pos, builder.getResultCount(results.size()));
// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
getTreePredicates(
predList, std::get<1>(it), builder, inputs,
builder.getAttribute(opPos,
std::get<0>(it).cast<StringAttr>().getValue()));
}
for (auto operandIt : llvm::enumerate(operands))
getTreePredicates(predList, operandIt.value(), builder, inputs,
builder.getOperand(opPos, operandIt.index()));
// Only recurse into results that are not referenced in the source tree.
for (auto resultIt : llvm::enumerate(results)) {
getTreePredicates(predList, resultIt.value(), builder, inputs,
builder.getResult(opPos, resultIt.index()));
}
break;
}
case Predicates::ResultPos: {
assert(val.getType().isa<pdl::ValueType>() && "expected value type");
pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp());
// Prevent traversing a null value.
predList.emplace_back(pos, builder.getIsNotNull());
// Traverse the type constraint.
unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber();
getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs,
builder.getType(pos));
break;
}
case Predicates::TypePos: {
assert(val.getType().isa<pdl::TypeType>() && "expected value type");
pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
// Check for a constraint on a constant type.
if (Optional<Type> type = typeOp.type())
predList.emplace_back(pos, builder.getTypeConstraint(*type));
break;
}
default:
llvm_unreachable("unknown position kind");
}
TypeSwitch<Position *>(pos)
.Case<AttributePosition, OperandPosition, OperationPosition,
TypePosition>([&](auto *derivedPos) {
getTreePredicates(predList, val, builder, inputs, derivedPos);
})
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}
/// Collect all of the predicates related to constraints within the given
/// pattern operation.
static void collectConstraintPredicates(
pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) {
for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) {
OperandRange arguments = op.args();
ArrayAttr parameters = op.constParamsAttr();
static void getConstraintPredicates(pdl::ApplyConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
OperandRange arguments = op.args();
ArrayAttr parameters = op.constParamsAttr();
std::vector<Position *> allPositions;
allPositions.reserve(arguments.size());
for (Value arg : arguments)
allPositions.push_back(inputs.lookup(arg));
std::vector<Position *> allPositions;
allPositions.reserve(arguments.size());
for (Value arg : arguments)
allPositions.push_back(inputs.lookup(arg));
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.name(), std::move(allPositions), parameters);
predList.emplace_back(pos, pred);
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.name(), std::move(allPositions), parameters);
predList.emplace_back(pos, pred);
}
static void getResultPredicates(pdl::ResultOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&resultPos = inputs[op];
if (resultPos)
return;
auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
resultPos = builder.getResult(parentPos, op.index());
predList.emplace_back(resultPos, builder.getIsNotNull());
}
/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
if (auto constraintOp = dyn_cast<pdl::ApplyConstraintOp>(&op))
getConstraintPredicates(constraintOp, predList, builder, inputs);
else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op))
getResultPredicates(resultOp, predList, builder, inputs);
}
}
@ -176,7 +207,7 @@ static void buildPredicateList(pdl::PatternOp pattern,
DenseMap<Value, Position *> &valueToPosition) {
getTreePredicates(predList, pattern.getRewriter().root(), builder,
valueToPosition, builder.getRoot());
collectConstraintPredicates(pattern, predList, builder, valueToPosition);
getNonTreePredicates(pattern, predList, builder, valueToPosition);
}
//===----------------------------------------------------------------------===//

View File

@ -28,21 +28,36 @@ void PDLDialect::initialize() {
registerTypes();
}
//===----------------------------------------------------------------------===//
// PDL Operations
//===----------------------------------------------------------------------===//
/// Returns true if the given operation is used by a "binding" pdl operation
/// within the main matcher body of a `pdl.pattern`.
static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) {
for (Operation *user : op->getUsers()) {
if (user->getBlock() != matcherBlock)
continue;
if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
return true;
// A result by itself is not binding, it must also be bound.
if (isa<ResultOp>(user) && hasBindingUseInMatcher(user, matcherBlock))
return true;
}
return false;
}
/// Returns success if the given operation is used by a "binding" pdl operation
/// within the main matcher body of a `pdl.pattern`. On failure, emits an error
/// with the given context message.
static LogicalResult
verifyHasBindingUseInMatcher(Operation *op,
StringRef bindableContextStr = "`pdl.operation`") {
// If the pattern is not a pattern, there is nothing to do.
if (!isa<PatternOp>(op->getParentOp()))
return success();
Block *matcherBlock = op->getBlock();
for (Operation *user : op->getUsers()) {
if (user->getBlock() != matcherBlock)
continue;
if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
return success();
}
if (hasBindingUseInMatcher(op, op->getBlock()))
return success();
return op->emitOpError()
<< "expected a bindable (i.e. " << bindableContextStr
<< ") user when defined in the matcher body of a `pdl.pattern`";
@ -86,37 +101,12 @@ static LogicalResult verify(OperandOp op) {
// pdl::OperationOp
//===----------------------------------------------------------------------===//
static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
static ParseResult parseOperationOpAttributes(
OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
ArrayAttr &attrNamesAttr) {
Builder &builder = p.getBuilder();
// Parse the optional operation name.
bool startsWithOperands = succeeded(p.parseOptionalLParen());
bool startsWithAttributes =
!startsWithOperands && succeeded(p.parseOptionalLBrace());
bool startsWithOpName = false;
if (!startsWithAttributes && !startsWithOperands) {
StringAttr opName;
OptionalParseResult opNameResult =
p.parseOptionalAttribute(opName, "name", state.attributes);
startsWithOpName = opNameResult.hasValue();
if (startsWithOpName && failed(*opNameResult))
return failure();
}
// Parse the operands.
SmallVector<OpAsmParser::OperandType, 4> operands;
if (startsWithOperands ||
(!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
if (p.parseOperandList(operands) || p.parseRParen() ||
p.resolveOperands(operands, builder.getType<ValueType>(),
state.operands))
return failure();
}
// Parse the attributes.
SmallVector<Attribute, 4> attrNames;
if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
SmallVector<OpAsmParser::OperandType, 4> attrOps;
if (succeeded(p.parseOptionalLBrace())) {
do {
StringAttr nameAttr;
OpAsmParser::OperandType operand;
@ -124,68 +114,29 @@ static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
attrOps.push_back(operand);
attrOperands.push_back(operand);
} while (succeeded(p.parseOptionalComma()));
if (p.parseRBrace() ||
p.resolveOperands(attrOps, builder.getType<AttributeType>(),
state.operands))
if (p.parseRBrace())
return failure();
}
state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
state.addTypes(builder.getType<OperationType>());
// Parse the result types.
SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
if (succeeded(p.parseOptionalArrow())) {
if (p.parseOperandList(opResultTypes) ||
p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
state.operands))
return failure();
state.types.append(opResultTypes.size(), builder.getType<ValueType>());
}
if (p.parseOptionalAttrDict(state.attributes))
return failure();
int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
static_cast<int32_t>(attrNames.size()),
static_cast<int32_t>(opResultTypes.size())};
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr(operandSegmentSizes));
attrNamesAttr = builder.getArrayAttr(attrNames);
return success();
}
static void print(OpAsmPrinter &p, OperationOp op) {
p << "pdl.operation ";
if (Optional<StringRef> name = op.name())
p << '"' << *name << '"';
auto operandValues = op.operands();
if (!operandValues.empty())
p << '(' << operandValues << ')';
// Emit the optional attributes.
ArrayAttr attrNames = op.attributeNames();
if (!attrNames.empty()) {
Operation::operand_range attrArgs = op.attributes();
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}
// Print the result type constraints of the operation.
if (!op.results().empty())
p << " -> " << op.types();
p.printOptionalAttrDict(op->getAttrs(),
{"attributeNames", "name", "operand_segment_sizes"});
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
OperandRange attrArgs,
ArrayAttr attrNames) {
if (attrNames.empty())
return;
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}
/// Verifies that the result types of this operation, defined within a
/// `pdl.rewrite`, can be inferred.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
ResultRange opResults,
OperandRange resultTypes) {
// Functor that returns if the given use can be used to infer a type.
Block *rewriterBlock = op->getBlock();
@ -207,8 +158,8 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
return success();
// Otherwise, make sure each of the types can be inferred.
for (int i : llvm::seq<int>(0, opResults.size())) {
Operation *resultTypeOp = resultTypes[i].getDefiningOp();
for (auto it : llvm::enumerate(resultTypes)) {
Operation *resultTypeOp = it.value().getDefiningOp();
assert(resultTypeOp && "expected valid result type operation");
// If the op was defined by a `create_native`, it is guaranteed to be
@ -229,14 +180,11 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
continue;
// Otherwise, check to see if any uses of the result can infer the type.
if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
continue;
return op
.emitOpError("must have inferable or constrained result types when "
"nested within `pdl.rewrite`")
.attachNote()
.append("result type #", i, " was not constrained");
.append("result type #", it.index(), " was not constrained");
}
return success();
}
@ -256,19 +204,10 @@ static LogicalResult verify(OperationOp op) {
<< " values";
}
OperandRange resultTypes = op.types();
auto opResults = op.results();
if (resultTypes.size() != opResults.size()) {
return op.emitOpError() << "expected the same number of result values and "
"result type constraints, got "
<< opResults.size() << " results and "
<< resultTypes.size() << " constraints";
}
// If the operation is within a rewrite body and doesn't have type inference,
// ensure that the result types can be resolved.
if (isWithinRewrite && !op.hasTypeInference()) {
if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
if (failed(verifyResultTypesAreInferrable(op, op.types())))
return failure();
}
@ -341,37 +280,9 @@ Optional<StringRef> PatternOp::getRootKind() {
//===----------------------------------------------------------------------===//
static LogicalResult verify(ReplaceOp op) {
auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
auto sourceOpResults = sourceOp.results();
auto replValues = op.replValues();
if (Value replOpVal = op.replOperation()) {
auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
auto replOpResults = replOp.results();
if (sourceOpResults.size() != replOpResults.size()) {
return op.emitOpError()
<< "expected source operation to have the same number of results "
"as the replacement operation, replacement operation provided "
<< replOpResults.size() << " but expected "
<< sourceOpResults.size();
}
if (!replValues.empty()) {
return op.emitOpError() << "expected no replacement values to be provided"
" when the replacement operation is present";
}
return success();
}
if (sourceOpResults.size() != replValues.size()) {
return op.emitOpError()
<< "expected source operation to have the same number of results "
"as the provided replacement values, found "
<< replValues.size() << " replacement values but expected "
<< sourceOpResults.size();
}
if (op.replOperation() && !op.replValues().empty())
return op.emitOpError() << "expected no replacement values to be provided"
" when the replacement operation is present";
return success();
}

View File

@ -63,15 +63,16 @@ module @constraints {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]] : !pdl.value, !pdl.value)
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]]
pdl.pattern : benefit(1) {
%input0 = pdl.operand
%input1 = pdl.operand
pdl.apply_constraint "multi_constraint"[true](%input0, %input1 : !pdl.value, !pdl.value)
%root = pdl.operation(%input0, %input1)
%result0 = pdl.result 0 of %root
pdl.apply_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
pdl.rewrite %root with "rewriter"
}
}
@ -107,19 +108,52 @@ module @results {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.check_result_count of %[[ROOT]] is 2
// Get the input and check the type.
// Get the result and check the type.
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value
// CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
// CHECK-DAG: pdl_interp.check_type %[[RESULT_TYPE]] is i32
// Get the second operand and check that it is equal to the first.
// CHECK-DAG: %[[RESULT1:.*]] = pdl_interp.get_result 1 of %[[ROOT]]
// CHECK-NOT: pdl_interp.get_value_type of %[[RESULT1]]
// The second result doesn't have any constraints, so we don't generate an
// access for it.
// CHECK-NOT: pdl_interp.get_result 1 of %[[ROOT]]
pdl.pattern : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%root = pdl.operation -> %type1, %type2
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @results_as_operands
module @results_as_operands {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// Get the first result and check it matches the first operand.
// CHECK-DAG: %[[OPERAND_0:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[DEF_OP_0:.*]] = pdl_interp.get_defining_op of %[[OPERAND_0]]
// CHECK-DAG: %[[RESULT_0:.*]] = pdl_interp.get_result 0 of %[[DEF_OP_0]]
// CHECK-DAG: pdl_interp.are_equal %[[RESULT_0]], %[[OPERAND_0]]
// Get the second result and check it matches the second operand.
// CHECK-DAG: %[[OPERAND_1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK-DAG: %[[DEF_OP_1:.*]] = pdl_interp.get_defining_op of %[[OPERAND_1]]
// CHECK-DAG: %[[RESULT_1:.*]] = pdl_interp.get_result 1 of %[[DEF_OP_1]]
// CHECK-DAG: pdl_interp.are_equal %[[RESULT_1]], %[[OPERAND_1]]
// Check that the parent operation of both results is the same.
// CHECK-DAG: pdl_interp.are_equal %[[DEF_OP_0]], %[[DEF_OP_1]]
pdl.pattern : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%inputOp = pdl.operation -> %type1, %type2
%result1 = pdl.result 0 of %inputOp
%result2 = pdl.result 1 of %inputOp
%root = pdl.operation(%result1, %result2)
pdl.rewrite %root with "rewriter"
}
}
@ -134,12 +168,12 @@ module @switch_result_types {
// CHECK: pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64]
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation -> %type
%root = pdl.operation -> %type
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%type = pdl.type : i64
%root, %result = pdl.operation -> %type
%root = pdl.operation -> %type
pdl.rewrite %root with "rewriter"
}
}
@ -161,13 +195,13 @@ module @predicate_ordering {
pdl.pattern : benefit(1) {
%resultType = pdl.type
pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type)
%root, %result = pdl.operation -> %resultType
%root = pdl.operation -> %resultType
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%resultType = pdl.type
%apply, %applyRes = pdl.operation -> %resultType
%apply = pdl.operation -> %resultType
pdl.rewrite %apply with "rewriter"
}
}

View File

@ -63,7 +63,8 @@ module @operation_operands {
%root = pdl.operation "foo.op"(%operand)
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %result = pdl.operation "foo.op"(%operand) -> %type
%newOp = pdl.operation "foo.op"(%operand) -> %type
%result = pdl.result 0 of %newOp
%newOp1 = pdl.operation "foo.op2"(%result)
pdl.erase %root
}
@ -84,7 +85,8 @@ module @operation_operands {
%root = pdl.operation "foo.op"(%operand)
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %result = pdl.operation "foo.op"(%operand) -> %type
%newOp = pdl.operation "foo.op"(%operand) -> %type
%result = pdl.result 0 of %newOp
%newOp1 = pdl.operation "foo.op2"(%result)
pdl.erase %root
}
@ -101,10 +103,10 @@ module @operation_result_types {
pdl.pattern : benefit(1) {
%rootType = pdl.type
%rootType1 = pdl.type
%root, %results:2 = pdl.operation "foo.op" -> %rootType, %rootType1
%root = pdl.operation "foo.op" -> %rootType, %rootType1
pdl.rewrite %root {
%newType1 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %rootType, %newType1
%newOp = pdl.operation "foo.op" -> %rootType, %newType1
pdl.replace %root with %newOp
}
}
@ -112,23 +114,6 @@ module @operation_result_types {
// -----
// CHECK-LABEL: module @operation_result_types_infer_from_value_replacement
module @operation_result_types_infer_from_value_replacement {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
pdl.pattern : benefit(1) {
%rootType = pdl.type
%root, %result = pdl.operation "foo.op" -> %rootType
pdl.rewrite %root {
%newType = pdl.type
%newOp, %newResult = pdl.operation "foo.op" -> %newType
pdl.replace %root with (%newResult)
}
}
}
// -----
// CHECK-LABEL: module @replace_with_op
module @replace_with_op {
// CHECK: module @rewriters
@ -138,9 +123,9 @@ module @replace_with_op {
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> %type
pdl.replace %root with %newOp
}
}
@ -157,9 +142,10 @@ module @replace_with_values {
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> %type
%newResult = pdl.result 0 of %newOp
pdl.replace %root with (%newResult)
}
}
@ -192,10 +178,10 @@ module @create_native {
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
pdl.pattern : benefit(1) {
%type = pdl.type
%root, %result = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
%newOp, %newResult = pdl.operation "foo.op" -> %newType
%newOp = pdl.operation "foo.op" -> %newType
pdl.replace %root with %newOp
}
}

View File

@ -24,7 +24,7 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected only one of [`type`, `value`] to be set}}
%attr = pdl.attribute : %type 10
%op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type
%op = pdl.operation "foo.op" {"attr" = %attr} -> %type
pdl.rewrite %op with "rewriter"
}
@ -108,7 +108,7 @@ pdl.pattern : benefit(1) {
// expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
// expected-note@below {{result type #0 was not constrained}}
%newOp, %result = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> %type
}
}
@ -147,28 +147,12 @@ pdl.pattern : benefit(1) {
// -----
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type
// expected-error@below {{to have the same number of results as the replacement operation}}
pdl.replace %root with %newOp
}
}
// -----
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %oldResult = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> %type
%newResult = pdl.result 0 of %newOp
// expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
"pdl.replace"(%root, %newOp, %newResult) {
@ -179,19 +163,6 @@ pdl.pattern : benefit(1) {
// -----
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type
// expected-error@below {{to have the same number of results as the provided replacement values}}
pdl.replace %root with (%newResult)
}
}
// -----
//===----------------------------------------------------------------------===//
// pdl::RewriteOp
//===----------------------------------------------------------------------===//

View File

@ -8,7 +8,8 @@ pdl.pattern @operations : benefit(1) {
// Operation with attributes and results.
%attribute = pdl.attribute
%type = pdl.type
%op0, %op0_result = pdl.operation {"attr" = %attribute} -> %type
%op0 = pdl.operation {"attr" = %attribute} -> %type
%op0_result = pdl.result 0 of %op0
// Operation with input.
%input = pdl.operand
@ -46,38 +47,23 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) {
pdl.pattern @infer_type_from_operation_replace : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%root = pdl.operation -> %type1, %type2
pdl.rewrite %root {
%type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
%newOp = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with %newOp
}
}
// -----
// Check that the result type of an operation within a rewrite can be inferred
// from a pdl.replace.
pdl.pattern @infer_type_from_result_replace : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
pdl.rewrite %root {
%type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with (%newResults#0, %newResults#1)
}
}
// -----
// Check that the result type of an operation within a rewrite can be inferred
// from a pdl.replace.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%root = pdl.operation -> %type1, %type2
pdl.rewrite %root {
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2
%newOp = pdl.operation "foo.op" -> %type1, %type2
}
}