forked from OSchip/llvm-project
929 lines
37 KiB
C++
929 lines
37 KiB
C++
//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
|
|
#include "../PassDetail.h"
|
|
#include "PredicateTree.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/ScopedHashTable.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdl_to_pdl_interp;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PatternLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class generators operations within the PDL Interpreter dialect from a
|
|
/// given module containing PDL pattern operations.
|
|
struct PatternLowering {
|
|
public:
|
|
PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
|
|
|
|
/// Generate code for matching and rewriting based on the pattern operations
|
|
/// within the module.
|
|
void lower(ModuleOp module);
|
|
|
|
private:
|
|
using ValueMap = llvm::ScopedHashTable<Position *, Value>;
|
|
using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
|
|
|
|
/// Generate interpreter operations for the tree rooted at the given matcher
|
|
/// node, in the specified region.
|
|
Block *generateMatcher(MatcherNode &node, Region ®ion);
|
|
|
|
/// Get or create an access to the provided positional value in the current
|
|
/// block. This operation may mutate the provided block pointer if nested
|
|
/// regions (i.e., pdl_interp.iterate) are required.
|
|
Value getValueAt(Block *¤tBlock, Position *pos);
|
|
|
|
/// Create the interpreter predicate operations. This operation may mutate the
|
|
/// provided current block pointer if nested regions (iterates) are required.
|
|
void generate(BoolNode *boolNode, Block *¤tBlock, Value val);
|
|
|
|
/// Create the interpreter switch / predicate operations, with several case
|
|
/// destinations. This operation never mutates the provided current block
|
|
/// pointer, because the switch operation does not need Values beyond `val`.
|
|
void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
|
|
|
|
/// Create the interpreter operations to record a successful pattern match
|
|
/// using the contained root operation. This operation may mutate the current
|
|
/// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
|
|
void generate(SuccessNode *successNode, Block *¤tBlock);
|
|
|
|
/// Generate a rewriter function for the given pattern operation, and returns
|
|
/// a reference to that function.
|
|
SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
|
|
SmallVectorImpl<Position *> &usedMatchValues);
|
|
|
|
/// Generate the rewriter code for the given operation.
|
|
void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::AttributeOp attrOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::EraseOp eraseOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::OperationOp operationOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::ReplaceOp replaceOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::ResultOp resultOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::ResultsOp resultOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::TypeOp typeOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
void generateRewriter(pdl::TypesOp typeOp,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
|
|
/// Generate the values used for resolving the result types of an operation
|
|
/// created within a dag rewriter region.
|
|
void generateOperationResultTypeRewriter(
|
|
pdl::OperationOp op, SmallVectorImpl<Value> &types,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue);
|
|
|
|
/// A builder to use when generating interpreter operations.
|
|
OpBuilder builder;
|
|
|
|
/// The matcher function used for all match related logic within PDL patterns.
|
|
FuncOp matcherFunc;
|
|
|
|
/// The rewriter module containing the all rewrite related logic within PDL
|
|
/// patterns.
|
|
ModuleOp rewriterModule;
|
|
|
|
/// The symbol table of the rewriter module used for insertion.
|
|
SymbolTable rewriterSymbolTable;
|
|
|
|
/// A scoped map connecting a position with the corresponding interpreter
|
|
/// value.
|
|
ValueMap values;
|
|
|
|
/// A stack of blocks used as the failure destination for matcher nodes that
|
|
/// don't have an explicit failure path.
|
|
SmallVector<Block *, 8> failureBlockStack;
|
|
|
|
/// A mapping between values defined in a pattern match, and the corresponding
|
|
/// positional value.
|
|
DenseMap<Value, Position *> valueToPosition;
|
|
|
|
/// The set of operation values whose whose location will be used for newly
|
|
/// generated operations.
|
|
SetVector<Value> locOps;
|
|
};
|
|
} // namespace
|
|
|
|
PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
|
|
: builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
|
|
rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
|
|
|
|
void PatternLowering::lower(ModuleOp module) {
|
|
PredicateUniquer predicateUniquer;
|
|
PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
|
|
|
|
// Define top-level scope for the arguments to the matcher function.
|
|
ValueMapScope topLevelValueScope(values);
|
|
|
|
// Insert the root operation, i.e. argument to the matcher, at the root
|
|
// position.
|
|
Block *matcherEntryBlock = matcherFunc.addEntryBlock();
|
|
values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
|
|
|
|
// Generate a root matcher node from the provided PDL module.
|
|
std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
|
|
module, predicateBuilder, valueToPosition);
|
|
Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
|
|
assert(failureBlockStack.empty() && "failed to empty the stack");
|
|
|
|
// After generation, merged the first matched block into the entry.
|
|
matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
|
|
firstMatcherBlock->getOperations());
|
|
firstMatcherBlock->erase();
|
|
}
|
|
|
|
Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) {
|
|
// Push a new scope for the values used by this matcher.
|
|
Block *block = ®ion.emplaceBlock();
|
|
ValueMapScope scope(values);
|
|
|
|
// If this is the return node, simply insert the corresponding interpreter
|
|
// finalize.
|
|
if (isa<ExitNode>(node)) {
|
|
builder.setInsertionPointToEnd(block);
|
|
builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
|
|
return block;
|
|
}
|
|
|
|
// Get the next block in the match sequence.
|
|
// This is intentionally executed first, before we get the value for the
|
|
// position associated with the node, so that we preserve an "there exist"
|
|
// semantics: if getting a value requires an upward traversal (going from a
|
|
// value to its consumers), we want to perform the check on all the consumers
|
|
// before we pass control to the failure node.
|
|
std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
|
|
Block *failureBlock;
|
|
if (failureNode) {
|
|
failureBlock = generateMatcher(*failureNode, region);
|
|
failureBlockStack.push_back(failureBlock);
|
|
} else {
|
|
assert(!failureBlockStack.empty() && "expected valid failure block");
|
|
failureBlock = failureBlockStack.back();
|
|
}
|
|
|
|
// If this node contains a position, get the corresponding value for this
|
|
// block.
|
|
Block *currentBlock = block;
|
|
Position *position = node.getPosition();
|
|
Value val = position ? getValueAt(currentBlock, position) : Value();
|
|
|
|
// If this value corresponds to an operation, record that we are going to use
|
|
// its location as part of a fused location.
|
|
bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
|
|
if (isOperationValue)
|
|
locOps.insert(val);
|
|
|
|
// Dispatch to the correct method based on derived node type.
|
|
TypeSwitch<MatcherNode *>(&node)
|
|
.Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
|
|
this->generate(derivedNode, currentBlock, val);
|
|
})
|
|
.Case([&](SuccessNode *successNode) {
|
|
generate(successNode, currentBlock);
|
|
});
|
|
|
|
// Pop all the failure blocks that were inserted due to nesting of
|
|
// pdl_interp.iterate.
|
|
while (failureBlockStack.back() != failureBlock) {
|
|
failureBlockStack.pop_back();
|
|
assert(!failureBlockStack.empty() && "unable to locate failure block");
|
|
}
|
|
|
|
// Pop the new failure block.
|
|
if (failureNode)
|
|
failureBlockStack.pop_back();
|
|
|
|
if (isOperationValue)
|
|
locOps.remove(val);
|
|
|
|
return block;
|
|
}
|
|
|
|
Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
|
|
if (Value val = values.lookup(pos))
|
|
return val;
|
|
|
|
// Get the value for the parent position.
|
|
Value parentVal;
|
|
if (Position *parent = pos->getParent())
|
|
parentVal = getValueAt(currentBlock, parent);
|
|
|
|
// TODO: Use a location from the position.
|
|
Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
|
|
builder.setInsertionPointToEnd(currentBlock);
|
|
Value value;
|
|
switch (pos->getKind()) {
|
|
case Predicates::OperationPos: {
|
|
auto *operationPos = cast<OperationPosition>(pos);
|
|
if (operationPos->isOperandDefiningOp())
|
|
// Standard (downward) traversal which directly follows the defining op.
|
|
value = builder.create<pdl_interp::GetDefiningOpOp>(
|
|
loc, builder.getType<pdl::OperationType>(), parentVal);
|
|
else
|
|
// A passthrough operation position.
|
|
value = parentVal;
|
|
break;
|
|
}
|
|
case Predicates::UsersPos: {
|
|
auto *usersPos = cast<UsersPosition>(pos);
|
|
|
|
// The first operation retrieves the representative value of a range.
|
|
// This applies only when the parent is a range of values and we were
|
|
// requested to use a representative value (e.g., upward traversal).
|
|
if (parentVal.getType().isa<pdl::RangeType>() &&
|
|
usersPos->useRepresentative())
|
|
value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
|
|
else
|
|
value = parentVal;
|
|
|
|
// The second operation retrieves the users.
|
|
value = builder.create<pdl_interp::GetUsersOp>(loc, value);
|
|
break;
|
|
}
|
|
case Predicates::ForEachPos: {
|
|
assert(!failureBlockStack.empty() && "expected valid failure block");
|
|
auto foreach = builder.create<pdl_interp::ForEachOp>(
|
|
loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
|
|
value = foreach.getLoopVariable();
|
|
|
|
// Create the continuation block.
|
|
Block *continueBlock = builder.createBlock(&foreach.region());
|
|
builder.create<pdl_interp::ContinueOp>(loc);
|
|
failureBlockStack.push_back(continueBlock);
|
|
|
|
currentBlock = &foreach.region().front();
|
|
break;
|
|
}
|
|
case Predicates::OperandPos: {
|
|
auto *operandPos = cast<OperandPosition>(pos);
|
|
value = builder.create<pdl_interp::GetOperandOp>(
|
|
loc, builder.getType<pdl::ValueType>(), parentVal,
|
|
operandPos->getOperandNumber());
|
|
break;
|
|
}
|
|
case Predicates::OperandGroupPos: {
|
|
auto *operandPos = cast<OperandGroupPosition>(pos);
|
|
Type valueTy = builder.getType<pdl::ValueType>();
|
|
value = builder.create<pdl_interp::GetOperandsOp>(
|
|
loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
|
|
parentVal, operandPos->getOperandGroupNumber());
|
|
break;
|
|
}
|
|
case Predicates::AttributePos: {
|
|
auto *attrPos = cast<AttributePosition>(pos);
|
|
value = builder.create<pdl_interp::GetAttributeOp>(
|
|
loc, builder.getType<pdl::AttributeType>(), parentVal,
|
|
attrPos->getName().strref());
|
|
break;
|
|
}
|
|
case Predicates::TypePos: {
|
|
if (parentVal.getType().isa<pdl::AttributeType>())
|
|
value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
|
|
else
|
|
value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
|
|
break;
|
|
}
|
|
case Predicates::ResultPos: {
|
|
auto *resPos = cast<ResultPosition>(pos);
|
|
value = builder.create<pdl_interp::GetResultOp>(
|
|
loc, builder.getType<pdl::ValueType>(), parentVal,
|
|
resPos->getResultNumber());
|
|
break;
|
|
}
|
|
case Predicates::ResultGroupPos: {
|
|
auto *resPos = cast<ResultGroupPosition>(pos);
|
|
Type valueTy = builder.getType<pdl::ValueType>();
|
|
value = builder.create<pdl_interp::GetResultsOp>(
|
|
loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
|
|
parentVal, resPos->getResultGroupNumber());
|
|
break;
|
|
}
|
|
case Predicates::AttributeLiteralPos: {
|
|
auto *attrPos = cast<AttributeLiteralPosition>(pos);
|
|
value =
|
|
builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
|
|
break;
|
|
}
|
|
case Predicates::TypeLiteralPos: {
|
|
auto *typePos = cast<TypeLiteralPosition>(pos);
|
|
Attribute rawTypeAttr = typePos->getValue();
|
|
if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
|
|
value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
|
|
else
|
|
value = builder.create<pdl_interp::CreateTypesOp>(
|
|
loc, rawTypeAttr.cast<ArrayAttr>());
|
|
break;
|
|
}
|
|
default:
|
|
llvm_unreachable("Generating unknown Position getter");
|
|
break;
|
|
}
|
|
|
|
values.insert(pos, value);
|
|
return value;
|
|
}
|
|
|
|
void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|
Value val) {
|
|
Location loc = val.getLoc();
|
|
Qualifier *question = boolNode->getQuestion();
|
|
Qualifier *answer = boolNode->getAnswer();
|
|
Region *region = currentBlock->getParent();
|
|
|
|
// Execute the getValue queries first, so that we create success
|
|
// matcher in the correct (possibly nested) region.
|
|
SmallVector<Value> args;
|
|
if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
|
|
args = {getValueAt(currentBlock, equalToQuestion->getValue())};
|
|
} else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
|
|
for (Position *position : cstQuestion->getArgs())
|
|
args.push_back(getValueAt(currentBlock, position));
|
|
}
|
|
|
|
// Generate the matcher in the current (potentially nested) region
|
|
// and get the failure successor.
|
|
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
|
|
Block *failure = failureBlockStack.back();
|
|
|
|
// Finally, create the predicate.
|
|
builder.setInsertionPointToEnd(currentBlock);
|
|
Predicates::Kind kind = question->getKind();
|
|
switch (kind) {
|
|
case Predicates::IsNotNullQuestion:
|
|
builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
|
|
break;
|
|
case Predicates::OperationNameQuestion: {
|
|
auto *opNameAnswer = cast<OperationNameAnswer>(answer);
|
|
builder.create<pdl_interp::CheckOperationNameOp>(
|
|
loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
|
|
break;
|
|
}
|
|
case Predicates::TypeQuestion: {
|
|
auto *ans = cast<TypeAnswer>(answer);
|
|
if (val.getType().isa<pdl::RangeType>())
|
|
builder.create<pdl_interp::CheckTypesOp>(
|
|
loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
|
|
else
|
|
builder.create<pdl_interp::CheckTypeOp>(
|
|
loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
|
|
break;
|
|
}
|
|
case Predicates::AttributeQuestion: {
|
|
auto *ans = cast<AttributeAnswer>(answer);
|
|
builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
|
|
success, failure);
|
|
break;
|
|
}
|
|
case Predicates::OperandCountAtLeastQuestion:
|
|
case Predicates::OperandCountQuestion:
|
|
builder.create<pdl_interp::CheckOperandCountOp>(
|
|
loc, val, cast<UnsignedAnswer>(answer)->getValue(),
|
|
/*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
|
|
success, failure);
|
|
break;
|
|
case Predicates::ResultCountAtLeastQuestion:
|
|
case Predicates::ResultCountQuestion:
|
|
builder.create<pdl_interp::CheckResultCountOp>(
|
|
loc, val, cast<UnsignedAnswer>(answer)->getValue(),
|
|
/*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
|
|
success, failure);
|
|
break;
|
|
case Predicates::EqualToQuestion: {
|
|
bool trueAnswer = isa<TrueAnswer>(answer);
|
|
builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
|
|
trueAnswer ? success : failure,
|
|
trueAnswer ? failure : success);
|
|
break;
|
|
}
|
|
case Predicates::ConstraintQuestion: {
|
|
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
|
builder.create<pdl_interp::ApplyConstraintOp>(
|
|
loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
|
|
failure);
|
|
break;
|
|
}
|
|
default:
|
|
llvm_unreachable("Generating unknown Predicate operation");
|
|
}
|
|
}
|
|
|
|
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
|
|
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
|
|
llvm::MapVector<Qualifier *, Block *> &dests) {
|
|
std::vector<ValT> values;
|
|
std::vector<Block *> blocks;
|
|
values.reserve(dests.size());
|
|
blocks.reserve(dests.size());
|
|
for (const auto &it : dests) {
|
|
blocks.push_back(it.second);
|
|
values.push_back(cast<PredT>(it.first)->getValue());
|
|
}
|
|
builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
|
|
}
|
|
|
|
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
|
|
Value val) {
|
|
Qualifier *question = switchNode->getQuestion();
|
|
Region *region = currentBlock->getParent();
|
|
Block *defaultDest = failureBlockStack.back();
|
|
|
|
// If the switch question is not an exact answer, i.e. for the `at_least`
|
|
// cases, we generate a special block sequence.
|
|
Predicates::Kind kind = question->getKind();
|
|
if (kind == Predicates::OperandCountAtLeastQuestion ||
|
|
kind == Predicates::ResultCountAtLeastQuestion) {
|
|
// Order the children such that the cases are in reverse numerical order.
|
|
SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
|
|
llvm::seq<unsigned>(0, switchNode->getChildren().size()));
|
|
llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
|
|
return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
|
|
cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
|
|
});
|
|
|
|
// Build the destination for each child using the next highest child as a
|
|
// a failure destination. This essentially creates the following control
|
|
// flow:
|
|
//
|
|
// if (operand_count < 1)
|
|
// goto failure
|
|
// if (child1.match())
|
|
// ...
|
|
//
|
|
// if (operand_count < 2)
|
|
// goto failure
|
|
// if (child2.match())
|
|
// ...
|
|
//
|
|
// failure:
|
|
// ...
|
|
//
|
|
failureBlockStack.push_back(defaultDest);
|
|
Location loc = val.getLoc();
|
|
for (unsigned idx : sortedChildren) {
|
|
auto &child = switchNode->getChild(idx);
|
|
Block *childBlock = generateMatcher(*child.second, *region);
|
|
Block *predicateBlock = builder.createBlock(childBlock);
|
|
builder.setInsertionPointToEnd(predicateBlock);
|
|
unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
|
|
switch (kind) {
|
|
case Predicates::OperandCountAtLeastQuestion:
|
|
builder.create<pdl_interp::CheckOperandCountOp>(
|
|
loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
|
|
break;
|
|
case Predicates::ResultCountAtLeastQuestion:
|
|
builder.create<pdl_interp::CheckResultCountOp>(
|
|
loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
|
|
break;
|
|
default:
|
|
llvm_unreachable("Generating invalid AtLeast operation");
|
|
}
|
|
failureBlockStack.back() = predicateBlock;
|
|
}
|
|
Block *firstPredicateBlock = failureBlockStack.pop_back_val();
|
|
currentBlock->getOperations().splice(currentBlock->end(),
|
|
firstPredicateBlock->getOperations());
|
|
firstPredicateBlock->erase();
|
|
return;
|
|
}
|
|
|
|
// Otherwise, generate each of the children and generate an interpreter
|
|
// switch.
|
|
llvm::MapVector<Qualifier *, Block *> children;
|
|
for (auto &it : switchNode->getChildren())
|
|
children.insert({it.first, generateMatcher(*it.second, *region)});
|
|
builder.setInsertionPointToEnd(currentBlock);
|
|
|
|
switch (question->getKind()) {
|
|
case Predicates::OperandCountQuestion:
|
|
return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
|
|
int32_t>(val, defaultDest, builder, children);
|
|
case Predicates::ResultCountQuestion:
|
|
return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
|
|
int32_t>(val, defaultDest, builder, children);
|
|
case Predicates::OperationNameQuestion:
|
|
return createSwitchOp<pdl_interp::SwitchOperationNameOp,
|
|
OperationNameAnswer>(val, defaultDest, builder,
|
|
children);
|
|
case Predicates::TypeQuestion:
|
|
if (val.getType().isa<pdl::RangeType>()) {
|
|
return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
|
|
val, defaultDest, builder, children);
|
|
}
|
|
return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
|
|
val, defaultDest, builder, children);
|
|
case Predicates::AttributeQuestion:
|
|
return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
|
|
val, defaultDest, builder, children);
|
|
default:
|
|
llvm_unreachable("Generating unknown switch predicate.");
|
|
}
|
|
}
|
|
|
|
void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
|
|
pdl::PatternOp pattern = successNode->getPattern();
|
|
Value root = successNode->getRoot();
|
|
|
|
// Generate a rewriter for the pattern this success node represents, and track
|
|
// any values used from the match region.
|
|
SmallVector<Position *, 8> usedMatchValues;
|
|
SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
|
|
|
|
// Process any values used in the rewrite that are defined in the match.
|
|
std::vector<Value> mappedMatchValues;
|
|
mappedMatchValues.reserve(usedMatchValues.size());
|
|
for (Position *position : usedMatchValues)
|
|
mappedMatchValues.push_back(getValueAt(currentBlock, position));
|
|
|
|
// Collect the set of operations generated by the rewriter.
|
|
SmallVector<StringRef, 4> generatedOps;
|
|
for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
|
|
generatedOps.push_back(*op.name());
|
|
ArrayAttr generatedOpsAttr;
|
|
if (!generatedOps.empty())
|
|
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
|
|
|
|
// Grab the root kind if present.
|
|
StringAttr rootKindAttr;
|
|
if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
|
|
if (Optional<StringRef> rootKind = rootOp.name())
|
|
rootKindAttr = builder.getStringAttr(*rootKind);
|
|
|
|
builder.setInsertionPointToEnd(currentBlock);
|
|
builder.create<pdl_interp::RecordMatchOp>(
|
|
pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
|
|
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
|
|
failureBlockStack.back());
|
|
}
|
|
|
|
SymbolRefAttr PatternLowering::generateRewriter(
|
|
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
|
|
FuncOp rewriterFunc =
|
|
FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
|
|
builder.getFunctionType(llvm::None, llvm::None));
|
|
rewriterSymbolTable.insert(rewriterFunc);
|
|
|
|
// Generate the rewriter function body.
|
|
builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
|
|
|
|
// Map an input operand of the pattern to a generated interpreter value.
|
|
DenseMap<Value, Value> rewriteValues;
|
|
auto mapRewriteValue = [&](Value oldValue) {
|
|
Value &newValue = rewriteValues[oldValue];
|
|
if (newValue)
|
|
return newValue;
|
|
|
|
// Prefer materializing constants directly when possible.
|
|
Operation *oldOp = oldValue.getDefiningOp();
|
|
if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
|
|
if (Attribute value = attrOp.valueAttr()) {
|
|
return newValue = builder.create<pdl_interp::CreateAttributeOp>(
|
|
attrOp.getLoc(), value);
|
|
}
|
|
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
|
|
if (TypeAttr type = typeOp.typeAttr()) {
|
|
return newValue = builder.create<pdl_interp::CreateTypeOp>(
|
|
typeOp.getLoc(), type);
|
|
}
|
|
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
|
|
if (ArrayAttr type = typeOp.typesAttr()) {
|
|
return newValue = builder.create<pdl_interp::CreateTypesOp>(
|
|
typeOp.getLoc(), typeOp.getType(), type);
|
|
}
|
|
}
|
|
|
|
// Otherwise, add this as an input to the rewriter.
|
|
Position *inputPos = valueToPosition.lookup(oldValue);
|
|
assert(inputPos && "expected value to be a pattern input");
|
|
usedMatchValues.push_back(inputPos);
|
|
return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
|
|
oldValue.getLoc());
|
|
};
|
|
|
|
// If this is a custom rewriter, simply dispatch to the registered rewrite
|
|
// method.
|
|
pdl::RewriteOp rewriter = pattern.getRewriter();
|
|
if (StringAttr rewriteName = rewriter.nameAttr()) {
|
|
SmallVector<Value> args;
|
|
if (rewriter.root())
|
|
args.push_back(mapRewriteValue(rewriter.root()));
|
|
auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
|
|
args.append(mappedArgs.begin(), mappedArgs.end());
|
|
builder.create<pdl_interp::ApplyRewriteOp>(
|
|
rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
|
|
rewriter.externalConstParamsAttr());
|
|
} else {
|
|
// Otherwise this is a dag rewriter defined using PDL operations.
|
|
for (Operation &rewriteOp : *rewriter.getBody()) {
|
|
llvm::TypeSwitch<Operation *>(&rewriteOp)
|
|
.Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
|
|
pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
|
|
pdl::TypeOp, pdl::TypesOp>([&](auto op) {
|
|
this->generateRewriter(op, rewriteValues, mapRewriteValue);
|
|
});
|
|
}
|
|
}
|
|
|
|
// Update the signature of the rewrite function.
|
|
rewriterFunc.setType(builder.getFunctionType(
|
|
llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
|
|
/*results=*/llvm::None));
|
|
|
|
builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
|
|
return SymbolRefAttr::get(
|
|
builder.getContext(),
|
|
pdl_interp::PDLInterpDialect::getRewriterModuleName(),
|
|
SymbolRefAttr::get(rewriterFunc));
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
SmallVector<Value, 2> arguments;
|
|
for (Value argument : rewriteOp.args())
|
|
arguments.push_back(mapRewriteValue(argument));
|
|
auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
|
|
rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
|
|
arguments, rewriteOp.constParamsAttr());
|
|
for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
|
|
rewriteValues[std::get<0>(it)] = std::get<1>(it);
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
|
|
attrOp.getLoc(), attrOp.valueAttr());
|
|
rewriteValues[attrOp] = newAttr;
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
|
|
mapRewriteValue(eraseOp.operation()));
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
SmallVector<Value, 4> operands;
|
|
for (Value operand : operationOp.operands())
|
|
operands.push_back(mapRewriteValue(operand));
|
|
|
|
SmallVector<Value, 4> attributes;
|
|
for (Value attr : operationOp.attributes())
|
|
attributes.push_back(mapRewriteValue(attr));
|
|
|
|
SmallVector<Value, 2> types;
|
|
generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
|
|
mapRewriteValue);
|
|
|
|
// Create the new operation.
|
|
Location loc = operationOp.getLoc();
|
|
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
|
|
loc, *operationOp.name(), types, operands, attributes,
|
|
operationOp.attributeNames());
|
|
rewriteValues[operationOp.op()] = createdOp;
|
|
|
|
// Generate accesses for any results that have their types constrained.
|
|
// Handle the case where there is a single range representing all of the
|
|
// result types.
|
|
OperandRange resultTys = operationOp.types();
|
|
if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
|
|
Value &type = rewriteValues[resultTys[0]];
|
|
if (!type) {
|
|
auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
|
|
type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
|
|
}
|
|
return;
|
|
}
|
|
|
|
// Otherwise, populate the individual results.
|
|
bool seenVariableLength = false;
|
|
Type valueTy = builder.getType<pdl::ValueType>();
|
|
Type valueRangeTy = pdl::RangeType::get(valueTy);
|
|
for (const auto &it : llvm::enumerate(resultTys)) {
|
|
Value &type = rewriteValues[it.value()];
|
|
if (type)
|
|
continue;
|
|
bool isVariadic = it.value().getType().isa<pdl::RangeType>();
|
|
seenVariableLength |= isVariadic;
|
|
|
|
// After a variable length result has been seen, we need to use result
|
|
// groups because the exact index of the result is not statically known.
|
|
Value resultVal;
|
|
if (seenVariableLength)
|
|
resultVal = builder.create<pdl_interp::GetResultsOp>(
|
|
loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
|
|
else
|
|
resultVal = builder.create<pdl_interp::GetResultOp>(
|
|
loc, valueTy, createdOp, it.index());
|
|
type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
|
|
}
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
SmallVector<Value, 4> replOperands;
|
|
|
|
// If the replacement was another operation, get its results. `pdl` allows
|
|
// for using an operation for simplicitly, but the interpreter isn't as
|
|
// user facing.
|
|
if (Value replOp = replaceOp.replOperation()) {
|
|
// Don't use replace if we know the replaced operation has no results.
|
|
auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
|
|
if (!opOp || !opOp.types().empty()) {
|
|
replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
|
|
replOp.getLoc(), mapRewriteValue(replOp)));
|
|
}
|
|
} else {
|
|
for (Value operand : replaceOp.replValues())
|
|
replOperands.push_back(mapRewriteValue(operand));
|
|
}
|
|
|
|
// If there are no replacement values, just create an erase instead.
|
|
if (replOperands.empty()) {
|
|
builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
|
|
mapRewriteValue(replaceOp.operation()));
|
|
return;
|
|
}
|
|
|
|
builder.create<pdl_interp::ReplaceOp>(
|
|
replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
|
|
resultOp.getLoc(), builder.getType<pdl::ValueType>(),
|
|
mapRewriteValue(resultOp.parent()), resultOp.index());
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
|
|
resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
|
|
resultOp.index());
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
|
|
// type.
|
|
if (TypeAttr typeAttr = typeOp.typeAttr()) {
|
|
rewriteValues[typeOp] =
|
|
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
|
|
}
|
|
}
|
|
|
|
void PatternLowering::generateRewriter(
|
|
pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
|
|
// type.
|
|
if (ArrayAttr typeAttr = typeOp.typesAttr()) {
|
|
rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
|
|
typeOp.getLoc(), typeOp.getType(), typeAttr);
|
|
}
|
|
}
|
|
|
|
void PatternLowering::generateOperationResultTypeRewriter(
|
|
pdl::OperationOp op, SmallVectorImpl<Value> &types,
|
|
DenseMap<Value, Value> &rewriteValues,
|
|
function_ref<Value(Value)> mapRewriteValue) {
|
|
// Look for an operation that was replaced by `op`. The result types will be
|
|
// inferred from the results that were replaced.
|
|
Block *rewriterBlock = op->getBlock();
|
|
for (OpOperand &use : op.op().getUses()) {
|
|
// Check that the use corresponds to a ReplaceOp and that it is the
|
|
// replacement value, not the operation being replaced.
|
|
pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
|
|
if (!replOpUser || use.getOperandNumber() == 0)
|
|
continue;
|
|
// Make sure the replaced operation was defined before this one.
|
|
Value replOpVal = replOpUser.operation();
|
|
Operation *replacedOp = replOpVal.getDefiningOp();
|
|
if (replacedOp->getBlock() == rewriterBlock &&
|
|
!replacedOp->isBeforeInBlock(op))
|
|
continue;
|
|
|
|
Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
|
|
replacedOp->getLoc(), mapRewriteValue(replOpVal));
|
|
types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
|
|
replacedOp->getLoc(), replacedOpResults));
|
|
return;
|
|
}
|
|
|
|
// Check if the operation has type inference support.
|
|
if (op.hasTypeInference()) {
|
|
types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
|
|
return;
|
|
}
|
|
|
|
// Otherwise, handle inference for each of the result types individually.
|
|
OperandRange resultTypeValues = op.types();
|
|
types.reserve(resultTypeValues.size());
|
|
for (const auto &it : llvm::enumerate(resultTypeValues)) {
|
|
Value resultType = it.value();
|
|
|
|
// Check for an already translated value.
|
|
if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
|
|
types.push_back(existingRewriteValue);
|
|
continue;
|
|
}
|
|
|
|
// Check for an input from the matcher.
|
|
if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
|
|
types.push_back(mapRewriteValue(resultType));
|
|
continue;
|
|
}
|
|
|
|
// The verifier asserts that the result types of each pdl.operation can be
|
|
// inferred. If we reach here, there is a bug either in the logic above or
|
|
// in the verifier for pdl.operation.
|
|
op->emitOpError() << "unable to infer result type for operation";
|
|
llvm_unreachable("unable to infer result type for operation");
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion Pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct PDLToPDLInterpPass
|
|
: public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
|
|
void runOnOperation() final;
|
|
};
|
|
} // namespace
|
|
|
|
/// Convert the given module containing PDL pattern operations into a PDL
|
|
/// Interpreter operations.
|
|
void PDLToPDLInterpPass::runOnOperation() {
|
|
ModuleOp module = getOperation();
|
|
|
|
// Create the main matcher function This function contains all of the match
|
|
// related functionality from patterns in the module.
|
|
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
|
|
FuncOp matcherFunc = builder.create<FuncOp>(
|
|
module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
|
|
builder.getFunctionType(builder.getType<pdl::OperationType>(),
|
|
/*results=*/llvm::None),
|
|
/*attrs=*/llvm::None);
|
|
|
|
// Create a nested module to hold the functions invoked for rewriting the IR
|
|
// after a successful match.
|
|
ModuleOp rewriterModule = builder.create<ModuleOp>(
|
|
module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
|
|
|
|
// Generate the code for the patterns within the module.
|
|
PatternLowering generator(matcherFunc, rewriterModule);
|
|
generator.lower(module);
|
|
|
|
// After generation, delete all of the pattern operations.
|
|
for (pdl::PatternOp pattern :
|
|
llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
|
|
pattern.erase();
|
|
}
|
|
|
|
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
|
|
return std::make_unique<PDLToPDLInterpPass>();
|
|
}
|