[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array (for opcodes, memory locations, etc.) and a memory buffer (for storing attributes, operations, values, and any other necessary data). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
2020-12-02 06:30:18 +08:00
|
|
|
//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file implements MLIR to byte-code generation and the interpreter.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <memory>
|
|
|
|
|
|
|
|
#define DEBUG_TYPE "pdl-bytecode"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::detail;
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PDLByteCodePattern
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Build the bytecode pattern described by a `pdl_interp.record_match`
/// operation whose rewriter begins at `rewriterAddr`.
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
                                              ByteCodeAddr rewriterAddr) {
  // Gather the names of the operations the rewriter declares it may generate,
  // if any were attached to the match.
  SmallVector<StringRef, 8> generatedOps;
  ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr();
  if (generatedOpsAttr) {
    generatedOps =
        llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
  }

  MLIRContext *ctx = matchOp.getContext();
  PatternBenefit benefit = matchOp.benefit();

  // Without a root kind the pattern is applicable to any operation type;
  // otherwise it is restricted to operations named by the root kind.
  Optional<StringRef> rootKind = matchOp.rootKind();
  if (!rootKind) {
    return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
                              MatchAnyOpTypeTag());
  }
  return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
                            ctx);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PDLByteCodeMutableState
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
|
|
|
|
/// to the position of the pattern within the range returned by
|
|
|
|
/// `PDLByteCode::getPatterns`.
|
|
|
|
void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
|
|
|
|
PatternBenefit benefit) {
|
|
|
|
currentPatternBenefits[patternIndex] = benefit;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Bytecode OpCodes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
enum OpCode : ByteCodeField {
|
|
|
|
/// Apply an externally registered constraint.
|
|
|
|
ApplyConstraint,
|
|
|
|
/// Apply an externally registered rewrite.
|
|
|
|
ApplyRewrite,
|
|
|
|
/// Check if two generic values are equal.
|
|
|
|
AreEqual,
|
|
|
|
/// Unconditional branch.
|
|
|
|
Branch,
|
|
|
|
/// Compare the operand count of an operation with a constant.
|
|
|
|
CheckOperandCount,
|
|
|
|
/// Compare the name of an operation with a constant.
|
|
|
|
CheckOperationName,
|
|
|
|
/// Compare the result count of an operation with a constant.
|
|
|
|
CheckResultCount,
|
|
|
|
/// Invoke a native creation method.
|
|
|
|
CreateNative,
|
|
|
|
/// Create an operation.
|
|
|
|
CreateOperation,
|
|
|
|
/// Erase an operation.
|
|
|
|
EraseOp,
|
|
|
|
/// Terminate a matcher or rewrite sequence.
|
|
|
|
Finalize,
|
|
|
|
/// Get a specific attribute of an operation.
|
|
|
|
GetAttribute,
|
|
|
|
/// Get the type of an attribute.
|
|
|
|
GetAttributeType,
|
|
|
|
/// Get the defining operation of a value.
|
|
|
|
GetDefiningOp,
|
|
|
|
/// Get a specific operand of an operation.
|
|
|
|
GetOperand0,
|
|
|
|
GetOperand1,
|
|
|
|
GetOperand2,
|
|
|
|
GetOperand3,
|
|
|
|
GetOperandN,
|
|
|
|
/// Get a specific result of an operation.
|
|
|
|
GetResult0,
|
|
|
|
GetResult1,
|
|
|
|
GetResult2,
|
|
|
|
GetResult3,
|
|
|
|
GetResultN,
|
|
|
|
/// Get the type of a value.
|
|
|
|
GetValueType,
|
|
|
|
/// Check if a generic value is not null.
|
|
|
|
IsNotNull,
|
|
|
|
/// Record a successful pattern match.
|
|
|
|
RecordMatch,
|
|
|
|
/// Replace an operation.
|
|
|
|
ReplaceOp,
|
|
|
|
/// Compare an attribute with a set of constants.
|
|
|
|
SwitchAttribute,
|
|
|
|
/// Compare the operand count of an operation with a set of constants.
|
|
|
|
SwitchOperandCount,
|
|
|
|
/// Compare the name of an operation with a set of constants.
|
|
|
|
SwitchOperationName,
|
|
|
|
/// Compare the result count of an operation with a set of constants.
|
|
|
|
SwitchResultCount,
|
|
|
|
/// Compare a type with a set of constants.
|
|
|
|
SwitchType,
|
|
|
|
};
|
|
|
|
|
|
|
|
enum class PDLValueKind { Attribute, Operation, Type, Value };
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ByteCode Generation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Generator
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
struct ByteCodeWriter;
|
|
|
|
|
|
|
|
/// This class represents the main generator for the pattern bytecode.
|
|
|
|
class Generator {
|
|
|
|
public:
|
|
|
|
Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
|
|
|
|
SmallVectorImpl<ByteCodeField> &matcherByteCode,
|
|
|
|
SmallVectorImpl<ByteCodeField> &rewriterByteCode,
|
|
|
|
SmallVectorImpl<PDLByteCodePattern> &patterns,
|
|
|
|
ByteCodeField &maxValueMemoryIndex,
|
|
|
|
llvm::StringMap<PDLConstraintFunction> &constraintFns,
|
|
|
|
llvm::StringMap<PDLCreateFunction> &createFns,
|
|
|
|
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
|
|
|
|
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
|
|
|
|
rewriterByteCode(rewriterByteCode), patterns(patterns),
|
|
|
|
maxValueMemoryIndex(maxValueMemoryIndex) {
|
|
|
|
for (auto it : llvm::enumerate(constraintFns))
|
|
|
|
constraintToMemIndex.try_emplace(it.value().first(), it.index());
|
|
|
|
for (auto it : llvm::enumerate(createFns))
|
|
|
|
nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
|
|
|
|
for (auto it : llvm::enumerate(rewriteFns))
|
|
|
|
externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generate the bytecode for the given PDL interpreter module.
|
|
|
|
void generate(ModuleOp module);
|
|
|
|
|
|
|
|
/// Return the memory index to use for the given value.
|
|
|
|
ByteCodeField &getMemIndex(Value value) {
|
|
|
|
assert(valueToMemIndex.count(value) &&
|
|
|
|
"expected memory index to be assigned");
|
|
|
|
return valueToMemIndex[value];
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return an index to use when referring to the given data that is uniqued in
|
|
|
|
/// the MLIR context.
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
|
|
|
|
getMemIndex(T val) {
|
|
|
|
const void *opaqueVal = val.getAsOpaquePointer();
|
|
|
|
|
|
|
|
// Get or insert a reference to this value.
|
|
|
|
auto it = uniquedDataToMemIndex.try_emplace(
|
|
|
|
opaqueVal, maxValueMemoryIndex + uniquedData.size());
|
|
|
|
if (it.second)
|
|
|
|
uniquedData.push_back(opaqueVal);
|
|
|
|
return it.first->second;
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
/// Allocate memory indices for the results of operations within the matcher
|
|
|
|
/// and rewriters.
|
|
|
|
void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
|
|
|
|
|
|
|
|
/// Generate the bytecode for the given operation.
|
|
|
|
void generate(Operation *op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
|
|
|
|
void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
|
|
|
|
|
|
|
|
/// Mapping from value to its corresponding memory index.
|
|
|
|
DenseMap<Value, ByteCodeField> valueToMemIndex;
|
|
|
|
|
|
|
|
/// Mapping from the name of an externally registered rewrite to its index in
|
|
|
|
/// the bytecode registry.
|
|
|
|
llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
|
|
|
|
|
|
|
|
/// Mapping from the name of an externally registered constraint to its index
|
|
|
|
/// in the bytecode registry.
|
|
|
|
llvm::StringMap<ByteCodeField> constraintToMemIndex;
|
|
|
|
|
|
|
|
/// Mapping from the name of an externally registered creation method to its
|
|
|
|
/// index in the bytecode registry.
|
|
|
|
llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
|
|
|
|
|
|
|
|
/// Mapping from rewriter function name to the bytecode address of the
|
|
|
|
/// rewriter function in byte.
|
|
|
|
llvm::StringMap<ByteCodeAddr> rewriterToAddr;
|
|
|
|
|
|
|
|
/// Mapping from a uniqued storage object to its memory index within
|
|
|
|
/// `uniquedData`.
|
|
|
|
DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
|
|
|
|
|
|
|
|
/// The current MLIR context.
|
|
|
|
MLIRContext *ctx;
|
|
|
|
|
|
|
|
/// Data of the ByteCode class to be populated.
|
|
|
|
std::vector<const void *> &uniquedData;
|
|
|
|
SmallVectorImpl<ByteCodeField> &matcherByteCode;
|
|
|
|
SmallVectorImpl<ByteCodeField> &rewriterByteCode;
|
|
|
|
SmallVectorImpl<PDLByteCodePattern> &patterns;
|
|
|
|
ByteCodeField &maxValueMemoryIndex;
|
|
|
|
};
|
|
|
|
|
|
|
|
/// This class provides utilities for writing a bytecode stream.
|
|
|
|
struct ByteCodeWriter {
|
|
|
|
ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
|
|
|
|
: bytecode(bytecode), generator(generator) {}
|
|
|
|
|
|
|
|
/// Append a field to the bytecode.
|
|
|
|
void append(ByteCodeField field) { bytecode.push_back(field); }
|
2020-12-02 09:08:38 +08:00
|
|
|
void append(OpCode opCode) { bytecode.push_back(opCode); }
|
[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
2020-12-02 06:30:18 +08:00
|
|
|
|
|
|
|
/// Append an address to the bytecode.
|
|
|
|
void append(ByteCodeAddr field) {
|
|
|
|
static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
|
|
|
|
"unexpected ByteCode address size");
|
|
|
|
|
|
|
|
ByteCodeField fieldParts[2];
|
|
|
|
std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
|
|
|
|
bytecode.append({fieldParts[0], fieldParts[1]});
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Append a successor range to the bytecode, the exact address will need to
|
|
|
|
/// be resolved later.
|
|
|
|
void append(SuccessorRange successors) {
|
|
|
|
// Add back references to the any successors so that the address can be
|
|
|
|
// resolved later.
|
|
|
|
for (Block *successor : successors) {
|
|
|
|
unresolvedSuccessorRefs[successor].push_back(bytecode.size());
|
|
|
|
append(ByteCodeAddr(0));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Append a range of values that will be read as generic PDLValues.
|
|
|
|
void appendPDLValueList(OperandRange values) {
|
|
|
|
bytecode.push_back(values.size());
|
|
|
|
for (Value value : values) {
|
|
|
|
// Append the type of the value in addition to the value itself.
|
|
|
|
PDLValueKind kind =
|
|
|
|
TypeSwitch<Type, PDLValueKind>(value.getType())
|
|
|
|
.Case<pdl::AttributeType>(
|
|
|
|
[](Type) { return PDLValueKind::Attribute; })
|
|
|
|
.Case<pdl::OperationType>(
|
|
|
|
[](Type) { return PDLValueKind::Operation; })
|
|
|
|
.Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
|
|
|
|
.Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
|
|
|
|
bytecode.push_back(static_cast<ByteCodeField>(kind));
|
|
|
|
append(value);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Check if the given class `T` has an iterator type.
|
|
|
|
template <typename T, typename... Args>
|
|
|
|
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
|
|
|
|
|
|
|
|
/// Append a value that will be stored in a memory slot and not inline within
|
|
|
|
/// the bytecode.
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
|
|
|
|
std::is_pointer<T>::value>
|
|
|
|
append(T value) {
|
|
|
|
bytecode.push_back(generator.getMemIndex(value));
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Append a range of values.
|
|
|
|
template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
|
|
|
|
std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
|
|
|
|
append(T range) {
|
|
|
|
bytecode.push_back(llvm::size(range));
|
|
|
|
for (auto it : range)
|
|
|
|
append(it);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Append a variadic number of fields to the bytecode.
|
|
|
|
template <typename FieldTy, typename Field2Ty, typename... FieldTys>
|
|
|
|
void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
|
|
|
|
append(field);
|
|
|
|
append(field2, fields...);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Successor references in the bytecode that have yet to be resolved.
|
|
|
|
DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
|
|
|
|
|
|
|
|
/// The underlying bytecode buffer.
|
|
|
|
SmallVectorImpl<ByteCodeField> &bytecode;
|
|
|
|
|
|
|
|
/// The main generator producing PDL.
|
|
|
|
Generator &generator;
|
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
void Generator::generate(ModuleOp module) {
  auto matcherFunc = module.lookupSymbol<FuncOp>(
      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
  auto rewriterModule = module.lookupSymbol<ModuleOp>(
      pdl_interp::PDLInterpDialect::getRewriterModuleName());
  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");

  // Memory slots are assigned up front, before any bytecode that refers to
  // them is emitted.
  allocateMemoryIndices(matcherFunc, rewriterModule);

  // Rewriter functions are emitted back-to-back into the rewriter stream; the
  // start address of each one is remembered by name.
  ByteCodeWriter rewriterWriter(rewriterByteCode, *this);
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeAddr funcStart = rewriterByteCode.size();
    rewriterToAddr.try_emplace(rewriterFunc.getName(), funcStart);
    for (Operation &rewriteOp : rewriterFunc.getOps())
      generate(&rewriteOp, rewriterWriter);
  }
  assert(rewriterWriter.unresolvedSuccessorRefs.empty() &&
         "unexpected branches in rewriter function");

  // The matcher's blocks are emitted in reverse post-order, noting where each
  // block starts so that branch placeholders can be patched afterwards.
  llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
  ByteCodeWriter matcherWriter(matcherByteCode, *this);
  DenseMap<Block *, ByteCodeAddr> blockToAddr;
  for (Block *block : rpot) {
    blockToAddr.try_emplace(block, matcherByteCode.size());
    for (Operation &matchOp : *block)
      generate(&matchOp, matcherWriter);
  }

  // Overwrite every placeholder address with the start of its target block.
  for (auto &pendingRefs : matcherWriter.unresolvedSuccessorRefs) {
    ByteCodeAddr targetAddr = blockToAddr[pendingRefs.first];
    for (unsigned placeholderOffset : pendingRefs.second)
      std::memcpy(&matcherByteCode[placeholderOffset], &targetAddr,
                  sizeof(ByteCodeAddr));
  }
}
|
|
|
|
|
|
|
|
void Generator::allocateMemoryIndices(FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme: every argument and every
  // operation result gets its own index, counting from zero per function.
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeField index = 0;
    for (BlockArgument arg : rewriterFunc.getArguments())
      valueToMemIndex.try_emplace(arg, index++);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        valueToMemIndex.try_emplace(result, index++);
    });
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  DenseMap<Operation *, ByteCodeField> opToIndex;
  matcherFunc.getBody().walk([&](Operation *op) {
    opToIndex.insert(std::make_pair(op, opToIndex.size()));
  });

  // Liveness info for each of the defs within the matcher.
  //
  // The IntervalMap used here declares a destructor that releases its
  // allocator-backed nodes but no copy/move constructor, so relocating one by
  // value (vector growth, DenseMap rehash) would leave two maps sharing nodes
  // that are then freed. Each map is therefore kept at a stable heap address
  // and only the owning pointer is moved around.
  using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
  LivenessSet::Allocator allocator;
  DenseMap<Value, std::unique_ptr<LivenessSet>> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the def interval that the value is used.
  Liveness matcherLiveness(matcherFunc);
  for (Block &block : matcherFunc.getBody()) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
    assert(info && "expected liveness info for block");
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Set indices for the range of this block that the value is used.
      std::unique_ptr<LivenessSet> &defRange = valueDefRanges[value];
      if (!defRange)
        defRange = std::make_unique<LivenessSet>(allocator);
      defRange->insert(opToIndex[firstUseOrDef],
                       opToIndex[info->getEndOperation(value, firstUseOrDef)],
                       /*dummyValue*/ 0);
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in())
      processValue(liveIn, &block.front());

    // Process any new defs within this block.
    for (Operation &op : block)
      for (Value result : op.getResults())
        processValue(result, &op);
  }

  // Greedily allocate memory slots using the computed def live ranges. Slot
  // `i + 1` corresponds to `allocatedIndices[i]`; slot 0 holds the root.
  std::vector<std::unique_ptr<LivenessSet>> allocatedIndices;
  for (auto &defIt : valueDefRanges) {
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    LivenessSet &defSet = *defIt.second;

    // Record the live range of this def as occupied within `target`.
    auto unionDefInto = [&](LivenessSet &target) {
      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
        target.insert(it.start(), it.stop(), /*dummyValue*/ 0);
    };

    // Try to allocate to an existing index.
    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
      LivenessSet &existingIndex = *existingIndexIt.value();
      llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
          defSet, existingIndex);
      if (overlaps.valid())
        continue;

      // Union the range of the def within the existing index, and stop: a def
      // occupies exactly one slot. Continuing would also mark the def live in
      // every later compatible slot, needlessly blocking them for other defs.
      unionDefInto(existingIndex);
      memIndex = existingIndexIt.index() + 1;
      break;
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.push_back(std::make_unique<LivenessSet>(allocator));
      unionDefInto(*allocatedIndices.back());
      memIndex = allocatedIndices.size();
    }
  }

  // Update the max number of indices.
  ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
  if (numMatcherIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numMatcherIndices;
}
|
|
|
|
|
|
|
|
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  // Forward every recognized interpreter operation to its typed overload.
  auto emitTyped = [&](auto interpOp) { this->generate(interpOp, writer); };
  auto rejectUnknown = [](Operation *) {
    llvm_unreachable("unknown `pdl_interp` operation");
  };

  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
            pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
            pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
            pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
            pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          emitTyped)
      .Default(rejectUnknown);
}
|
|
|
|
|
|
|
|
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  // The constraint is referenced by the registry index assigned in the
  // generator's constructor.
  assert(constraintToMemIndex.count(op.name()) &&
         "expected index for constraint function");
  ByteCodeField constraintIndex = constraintToMemIndex[op.name()];

  // Layout: opcode, constraint index, constant params, argument list, and
  // finally the successor addresses.
  writer.append(OpCode::ApplyConstraint, constraintIndex,
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
  writer.append(op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  // The rewrite is referenced by the registry index assigned in the
  // generator's constructor.
  assert(externalRewriterToMemIndex.count(op.name()) &&
         "expected index for rewrite function");
  ByteCodeField rewriteIndex = externalRewriterToMemIndex[op.name()];

  // Layout: opcode, rewrite index, constant params, root, argument list.
  writer.append(OpCode::ApplyRewrite, rewriteIndex, op.constParamsAttr(),
                op.root());
  writer.appendPDLValueList(op.args());
}
|
|
|
|
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  // Generic equality: both sides are referenced by memory slot.
  auto lhs = op.lhs();
  auto rhs = op.rhs();
  writer.append(OpCode::AreEqual, lhs, rhs, op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
|
2020-12-02 10:13:27 +08:00
|
|
|
writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
|
[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
2020-12-02 06:30:18 +08:00
|
|
|
}
|
|
|
|
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  // Checking an attribute against a constant is lowered to the generic
  // equality opcode.
  auto expectedAttr = op.constantValue();
  writer.append(OpCode::AreEqual, op.attribute(), expectedAttr,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, operation slot, expected operand count, successors.
  auto expectedCount = op.count();
  writer.append(OpCode::CheckOperandCount, op.operation(), expectedCount,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  // The expected name is materialized as an OperationName so the executor can
  // compare it directly against the operation's name.
  OperationName expectedName(op.name(), ctx);
  writer.append(OpCode::CheckOperationName, op.operation(), expectedName,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, operation slot, expected result count, successors.
  auto expectedCount = op.count();
  writer.append(OpCode::CheckResultCount, op.operation(), expectedCount,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  // Type checks reuse the generic `AreEqual` opcode, comparing the value
  // operand against the expected type.
  auto expectedType = op.type();
  writer.append(OpCode::AreEqual, op.value(), expectedType,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // No bytecode is emitted: the created attribute simply aliases the memory
  // slot that already holds the constant value.
  auto constantIndex = getMemIndex(op.value());
  getMemIndex(op.attribute()) = constantIndex;
}
|
|
|
|
void Generator::generate(pdl_interp::CreateNativeOp op,
                         ByteCodeWriter &writer) {
  assert(nativeCreateToMemIndex.count(op.name()) &&
         "expected index for creation function");

  // Layout: opcode, native create-function index, result slot, constant
  // parameters, then the arguments (read back by the executor as a list of
  // PDLValues).
  auto createFnIndex = nativeCreateToMemIndex[op.name()];
  writer.append(OpCode::CreateNative, createFnIndex, op.result(),
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
}
|
|
|
|
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  // Header: opcode, result slot, operation name, and the operand list.
  writer.append(OpCode::CreateOperation, op.operation(),
                OperationName(op.name(), ctx), op.operands());

  // Attributes: a count followed by (name, value) pairs.
  OperandRange attrValues = op.attributes();
  writer.append(static_cast<ByteCodeField>(attrValues.size()));
  for (auto nameAndValue : llvm::zip(op.attributeNames(), attrValues)) {
    auto attrName = std::get<0>(nameAndValue).cast<StringAttr>().getValue();
    writer.append(Identifier::get(attrName, ctx), std::get<1>(nameAndValue));
  }

  // Trailer: the list of result types (null entries mark inferred types).
  writer.append(op.types());
}
|
|
|
|
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // No bytecode is emitted: the created type simply aliases the memory slot
  // that already holds the constant value.
  auto constantIndex = getMemIndex(op.value());
  getMemIndex(op.result()) = constantIndex;
}
|
|
|
|
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  // The only operand is the memory slot of the operation to erase.
  auto erasedOp = op.operation();
  writer.append(OpCode::EraseOp, erasedOp);
}
|
|
|
|
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  // Finalize carries no operands; the executor returns when it reaches it.
  const OpCode finalizeCode = OpCode::Finalize;
  writer.append(finalizeCode);
}
|
|
|
|
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, result slot, operation slot, attribute name.
  Identifier attrName = Identifier::get(op.name(), ctx);
  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
                attrName);
}
|
|
|
|
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, result slot, attribute slot.
  auto resultSlot = op.result();
  auto attrValue = op.value();
  writer.append(OpCode::GetAttributeType, resultSlot, attrValue);
}
|
|
|
|
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, result (operation) slot, value slot.
  auto resultSlot = op.operation();
  auto definedValue = op.value();
  writer.append(OpCode::GetDefiningOp, resultSlot, definedValue);
}
|
|
|
|
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  // Operand indices 0-3 have dedicated opcodes that encode the index
  // directly; any larger index uses GetOperandN with the index stored inline.
  const uint32_t operandIdx = op.index();
  if (operandIdx >= 4) {
    writer.append(OpCode::GetOperandN, operandIdx);
  } else {
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + operandIdx));
  }

  // Followed by the source operation slot and the result slot.
  writer.append(op.operation(), op.value());
}
|
|
|
|
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  // Result indices 0-3 have dedicated opcodes that encode the index directly;
  // any larger index uses GetResultN with the index stored inline.
  const uint32_t resultIdx = op.index();
  if (resultIdx >= 4) {
    writer.append(OpCode::GetResultN, resultIdx);
  } else {
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + resultIdx));
  }

  // Followed by the source operation slot and the result slot.
  writer.append(op.operation(), op.value());
}
|
|
|
|
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, result (type) slot, value slot.
  auto resultSlot = op.result();
  auto typedValue = op.value();
  writer.append(OpCode::GetValueType, resultSlot, typedValue);
}
|
|
|
|
void Generator::generate(pdl_interp::InferredTypeOp op,
                         ByteCodeWriter &writer) {
  // No bytecode is emitted. The result aliases the slot of a null `Type`,
  // which the CreateOperation opcode treats as "infer this result type".
  auto nullTypeIndex = getMemIndex(Type());
  getMemIndex(op.type()) = nullTypeIndex;
}
|
|
|
|
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  // Layout: opcode, slot to test, then the true/false successors.
  auto testedValue = op.value();
  writer.append(OpCode::IsNotNull, testedValue, op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
|
|
|
|
ByteCodeField patternIndex = patterns.size();
|
|
|
|
patterns.emplace_back(PDLByteCodePattern::create(
|
|
|
|
op, rewriterToAddr[op.rewriter().getLeafReference()]));
|
2020-12-02 10:13:27 +08:00
|
|
|
writer.append(OpCode::RecordMatch, patternIndex,
|
|
|
|
SuccessorRange(op.getOperation()), op.matchedOps(),
|
|
|
|
op.inputs());
|
[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
2020-12-02 06:30:18 +08:00
|
|
|
}
|
|
|
|
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  // Layout: opcode, slot of the operation to replace, replacement value list.
  auto replacements = op.replValues();
  writer.append(OpCode::ReplaceOp, op.operation(), replacements);
}
|
|
|
|
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, attribute slot, case array, then the default successor
  // followed by one successor per case.
  auto caseValues = op.caseValuesAttr();
  writer.append(OpCode::SwitchAttribute, op.attribute(), caseValues,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, operation slot, case counts, then the successors.
  auto caseCounts = op.caseValuesAttr();
  writer.append(OpCode::SwitchOperandCount, op.operation(), caseCounts,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  // Each case string becomes an OperationName. The names are written inline
  // as a counted list, followed by the successors.
  auto toOperationName = [this](Attribute caseAttr) {
    return OperationName(caseAttr.cast<StringAttr>().getValue(), ctx);
  };
  auto caseNames = llvm::map_range(op.caseValuesAttr(), toOperationName);
  writer.append(OpCode::SwitchOperationName, op.operation(), caseNames,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  // Layout: opcode, operation slot, case counts, then the successors.
  auto caseCounts = op.caseValuesAttr();
  writer.append(OpCode::SwitchResultCount, op.operation(), caseCounts,
                op.getSuccessors());
}
|
|
|
|
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  // Layout: opcode, type slot, case array, then the successors.
  auto caseTypes = op.caseValuesAttr();
  writer.append(OpCode::SwitchType, op.value(), caseTypes,
                op.getSuccessors());
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PDLByteCode
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLCreateFunction> createFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  // Lower the module into matcher/rewriter bytecode. The generator is handed
  // the registered function maps before their values are moved out below.
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      constraintFns, createFns, rewriteFns);
  generator.generate(module);

  // Flatten each function map into the table the executor indexes into.
  // NOTE(review): the executor indexes these tables with indices emitted by
  // the generator, so presumably the generator numbers functions in this same
  // map iteration order -- keep the two in sync.
  auto moveValuesInto = [](auto &fnMap, auto &fnTable) {
    for (auto &entry : fnMap)
      fnTable.push_back(std::move(entry.second));
  };
  moveValuesInto(constraintFns, constraintFunctions);
  moveValuesInto(createFns, createFunctions);
  moveValuesInto(rewriteFns, rewriteFunctions);
}
|
|
|
|
|
|
|
|
/// Initialize the given state such that it can be used to execute the current
|
|
|
|
/// bytecode.
|
|
|
|
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
|
|
|
|
state.memory.resize(maxValueMemoryIndex, nullptr);
|
|
|
|
state.currentPatternBenefits.reserve(patterns.size());
|
|
|
|
for (const PDLByteCodePattern &pattern : patterns)
|
|
|
|
state.currentPatternBenefits.push_back(pattern.getBenefit());
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ByteCode Execution
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// This class provides support for executing a bytecode stream.
|
|
|
|
class ByteCodeExecutor {
|
|
|
|
public:
|
|
|
|
ByteCodeExecutor(const ByteCodeField *curCodeIt,
|
|
|
|
MutableArrayRef<const void *> memory,
|
|
|
|
ArrayRef<const void *> uniquedMemory,
|
|
|
|
ArrayRef<ByteCodeField> code,
|
|
|
|
ArrayRef<PatternBenefit> currentPatternBenefits,
|
|
|
|
ArrayRef<PDLByteCodePattern> patterns,
|
|
|
|
ArrayRef<PDLConstraintFunction> constraintFunctions,
|
|
|
|
ArrayRef<PDLCreateFunction> createFunctions,
|
|
|
|
ArrayRef<PDLRewriteFunction> rewriteFunctions)
|
|
|
|
: curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
|
|
|
|
code(code), currentPatternBenefits(currentPatternBenefits),
|
|
|
|
patterns(patterns), constraintFunctions(constraintFunctions),
|
|
|
|
createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
|
|
|
|
|
|
|
|
/// Start executing the code at the current bytecode index. `matches` is an
|
|
|
|
/// optional field provided when this function is executed in a matching
|
|
|
|
/// context.
|
|
|
|
void execute(PatternRewriter &rewriter,
|
|
|
|
SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
|
|
|
|
Optional<Location> mainRewriteLoc = {});
|
|
|
|
|
|
|
|
private:
|
|
|
|
/// Read a value from the bytecode buffer, optionally skipping a certain
|
|
|
|
/// number of prefix values. These methods always update the buffer to point
|
|
|
|
/// to the next field after the read data.
|
|
|
|
template <typename T = ByteCodeField>
|
|
|
|
T read(size_t skipN = 0) {
|
|
|
|
curCodeIt += skipN;
|
|
|
|
return readImpl<T>();
|
|
|
|
}
|
|
|
|
ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
|
|
|
|
|
|
|
|
/// Read a list of values from the bytecode buffer.
|
|
|
|
template <typename ValueT, typename T>
|
|
|
|
void readList(SmallVectorImpl<T> &list) {
|
|
|
|
list.clear();
|
|
|
|
for (unsigned i = 0, e = read(); i != e; ++i)
|
|
|
|
list.push_back(read<ValueT>());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Jump to a specific successor based on a predicate value.
|
|
|
|
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
|
|
|
|
/// Jump to a specific successor based on a destination index.
|
|
|
|
void selectJump(size_t destIndex) {
|
|
|
|
curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Handle a switch operation with the provided value and cases.
|
|
|
|
template <typename T, typename RangeT>
|
|
|
|
void handleSwitch(const T &value, RangeT &&cases) {
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << " * Value: " << value << "\n"
|
|
|
|
<< " * Cases: ";
|
|
|
|
llvm::interleaveComma(cases, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n\n";
|
|
|
|
});
|
|
|
|
|
|
|
|
// Check to see if the attribute value is within the case list. Jump to
|
|
|
|
// the correct successor index based on the result.
|
2020-12-03 02:42:40 +08:00
|
|
|
for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
|
|
|
|
if (*it == value)
|
|
|
|
return selectJump(size_t((it - cases.begin()) + 1));
|
|
|
|
selectJump(size_t(0));
|
[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
2020-12-02 06:30:18 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Internal implementation of reading various data types from the bytecode
|
|
|
|
/// stream.
|
|
|
|
template <typename T>
|
|
|
|
const void *readFromMemory() {
|
|
|
|
size_t index = *curCodeIt++;
|
|
|
|
|
|
|
|
// If this type is an SSA value, it can only be stored in non-const memory.
|
|
|
|
if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
|
|
|
|
return memory[index];
|
|
|
|
|
|
|
|
// Otherwise, if this index is not inbounds it is uniqued.
|
|
|
|
return uniquedMemory[index - memory.size()];
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
|
|
|
|
return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
|
|
|
|
T>
|
|
|
|
readImpl() {
|
|
|
|
return T(T::getFromOpaquePointer(readFromMemory<T>()));
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
|
|
|
|
switch (static_cast<PDLValueKind>(read())) {
|
|
|
|
case PDLValueKind::Attribute:
|
|
|
|
return read<Attribute>();
|
|
|
|
case PDLValueKind::Operation:
|
|
|
|
return read<Operation *>();
|
|
|
|
case PDLValueKind::Type:
|
|
|
|
return read<Type>();
|
|
|
|
case PDLValueKind::Value:
|
|
|
|
return read<Value>();
|
|
|
|
}
|
2021-01-19 13:59:15 +08:00
|
|
|
llvm_unreachable("unhandled PDLValueKind");
|
[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
2020-12-02 06:30:18 +08:00
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
|
|
|
|
static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
|
|
|
|
"unexpected ByteCode address size");
|
|
|
|
ByteCodeAddr result;
|
|
|
|
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
|
|
|
|
curCodeIt += 2;
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
|
|
|
|
return *curCodeIt++;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// The underlying bytecode buffer.
|
|
|
|
const ByteCodeField *curCodeIt;
|
|
|
|
|
|
|
|
/// The current execution memory.
|
|
|
|
MutableArrayRef<const void *> memory;
|
|
|
|
|
|
|
|
/// References to ByteCode data necessary for execution.
|
|
|
|
ArrayRef<const void *> uniquedMemory;
|
|
|
|
ArrayRef<ByteCodeField> code;
|
|
|
|
ArrayRef<PatternBenefit> currentPatternBenefits;
|
|
|
|
ArrayRef<PDLByteCodePattern> patterns;
|
|
|
|
ArrayRef<PDLConstraintFunction> constraintFunctions;
|
|
|
|
ArrayRef<PDLCreateFunction> createFunctions;
|
|
|
|
ArrayRef<PDLRewriteFunction> rewriteFunctions;
|
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
void ByteCodeExecutor::execute(
|
|
|
|
PatternRewriter &rewriter,
|
|
|
|
SmallVectorImpl<PDLByteCode::MatchResult> *matches,
|
|
|
|
Optional<Location> mainRewriteLoc) {
|
|
|
|
while (true) {
|
|
|
|
OpCode opCode = static_cast<OpCode>(read());
|
|
|
|
switch (opCode) {
|
|
|
|
case ApplyConstraint: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
|
|
|
|
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
|
|
|
|
ArrayAttr constParams = read<ArrayAttr>();
|
|
|
|
SmallVector<PDLValue, 16> args;
|
|
|
|
readList<PDLValue>(args);
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << " * Arguments: ";
|
|
|
|
llvm::interleaveComma(args, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
|
|
|
|
});
|
|
|
|
|
|
|
|
// Invoke the constraint and jump to the proper destination.
|
|
|
|
selectJump(succeeded(constraintFn(args, constParams, rewriter)));
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case ApplyRewrite: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
|
|
|
|
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
|
|
|
|
ArrayAttr constParams = read<ArrayAttr>();
|
|
|
|
Operation *root = read<Operation *>();
|
|
|
|
SmallVector<PDLValue, 16> args;
|
|
|
|
readList<PDLValue>(args);
|
|
|
|
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << " * Root: " << *root << "\n"
|
|
|
|
<< " * Arguments: ";
|
|
|
|
llvm::interleaveComma(args, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
|
|
|
|
});
|
|
|
|
rewriteFn(root, args, constParams, rewriter);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case AreEqual: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
|
|
|
|
const void *lhs = read<const void *>();
|
|
|
|
const void *rhs = read<const void *>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
|
|
|
|
selectJump(lhs == rhs);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case Branch: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
|
|
|
|
curCodeIt = &code[read<ByteCodeAddr>()];
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case CheckOperandCount: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
uint32_t expectedCount = read<uint32_t>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
|
|
|
|
<< " * Expected: " << expectedCount << "\n\n");
|
|
|
|
selectJump(op->getNumOperands() == expectedCount);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case CheckOperationName: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
OperationName expectedName = read<OperationName>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
|
|
<< " * Found: \"" << op->getName() << "\"\n"
|
|
|
|
<< " * Expected: \"" << expectedName << "\"\n\n");
|
|
|
|
selectJump(op->getName() == expectedName);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case CheckResultCount: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
uint32_t expectedCount = read<uint32_t>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
|
|
|
|
<< " * Expected: " << expectedCount << "\n\n");
|
|
|
|
selectJump(op->getNumResults() == expectedCount);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case CreateNative: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
|
|
|
|
const PDLCreateFunction &createFn = createFunctions[read()];
|
|
|
|
ByteCodeField resultIndex = read();
|
|
|
|
ArrayAttr constParams = read<ArrayAttr>();
|
|
|
|
SmallVector<PDLValue, 16> args;
|
|
|
|
readList<PDLValue>(args);
|
|
|
|
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << " * Arguments: ";
|
|
|
|
llvm::interleaveComma(args, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
|
|
|
|
});
|
|
|
|
|
|
|
|
PDLValue result = createFn(args, constParams, rewriter);
|
|
|
|
memory[resultIndex] = result.getAsOpaquePointer();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n");
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case CreateOperation: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
|
|
|
|
assert(mainRewriteLoc && "expected rewrite loc to be provided when "
|
|
|
|
"executing the rewriter bytecode");
|
|
|
|
|
|
|
|
unsigned memIndex = read();
|
|
|
|
OperationState state(*mainRewriteLoc, read<OperationName>());
|
|
|
|
readList<Value>(state.operands);
|
|
|
|
for (unsigned i = 0, e = read(); i != e; ++i) {
|
|
|
|
Identifier name = read<Identifier>();
|
|
|
|
if (Attribute attr = read<Attribute>())
|
|
|
|
state.addAttribute(name, attr);
|
|
|
|
}
|
|
|
|
|
|
|
|
bool hasInferredTypes = false;
|
|
|
|
for (unsigned i = 0, e = read(); i != e; ++i) {
|
|
|
|
Type resultType = read<Type>();
|
|
|
|
hasInferredTypes |= !resultType;
|
|
|
|
state.types.push_back(resultType);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Handle the case where the operation has inferred types.
|
|
|
|
if (hasInferredTypes) {
|
|
|
|
InferTypeOpInterface::Concept *concept =
|
|
|
|
state.name.getAbstractOperation()
|
|
|
|
->getInterface<InferTypeOpInterface>();
|
|
|
|
|
|
|
|
// TODO: Handle failure.
|
|
|
|
SmallVector<Type, 2> inferredTypes;
|
|
|
|
if (failed(concept->inferReturnTypes(
|
|
|
|
state.getContext(), state.location, state.operands,
|
|
|
|
state.attributes.getDictionary(state.getContext()),
|
|
|
|
state.regions, inferredTypes)))
|
|
|
|
return;
|
|
|
|
|
|
|
|
for (unsigned i = 0, e = state.types.size(); i != e; ++i)
|
|
|
|
if (!state.types[i])
|
|
|
|
state.types[i] = inferredTypes[i];
|
|
|
|
}
|
|
|
|
Operation *resultOp = rewriter.createOperation(state);
|
|
|
|
memory[memIndex] = resultOp;
|
|
|
|
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << " * Attributes: "
|
|
|
|
<< state.attributes.getDictionary(state.getContext())
|
|
|
|
<< "\n * Operands: ";
|
|
|
|
llvm::interleaveComma(state.operands, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n * Result Types: ";
|
|
|
|
llvm::interleaveComma(state.types, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n";
|
|
|
|
});
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case EraseOp: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n");
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case Finalize: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
case GetAttribute: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
|
|
|
|
unsigned memIndex = read();
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
Identifier attrName = read<Identifier>();
|
|
|
|
Attribute attr = op->getAttr(attrName);
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
|
|
|
|
<< " * Attribute: " << attrName << "\n"
|
|
|
|
<< " * Result: " << attr << "\n\n");
|
|
|
|
memory[memIndex] = attr.getAsOpaquePointer();
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case GetAttributeType: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
|
|
|
|
unsigned memIndex = read();
|
|
|
|
Attribute attr = read<Attribute>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
|
|
|
|
<< " * Result: " << attr.getType() << "\n\n");
|
|
|
|
memory[memIndex] = attr.getType().getAsOpaquePointer();
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case GetDefiningOp: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
|
|
|
|
unsigned memIndex = read();
|
|
|
|
Value value = read<Value>();
|
|
|
|
Operation *op = value ? value.getDefiningOp() : nullptr;
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
|
|
|
|
<< " * Result: " << *op << "\n\n");
|
|
|
|
memory[memIndex] = op;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case GetOperand0:
|
|
|
|
case GetOperand1:
|
|
|
|
case GetOperand2:
|
|
|
|
case GetOperand3:
|
|
|
|
case GetOperandN: {
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << "Executing GetOperand"
|
|
|
|
<< (opCode == GetOperandN ? Twine("N")
|
|
|
|
: Twine(opCode - GetOperand0))
|
|
|
|
<< ":\n";
|
|
|
|
});
|
|
|
|
unsigned index =
|
|
|
|
opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
unsigned memIndex = read();
|
|
|
|
Value operand =
|
|
|
|
index < op->getNumOperands() ? op->getOperand(index) : Value();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
|
|
|
|
<< " * Index: " << index << "\n"
|
|
|
|
<< " * Result: " << operand << "\n\n");
|
|
|
|
memory[memIndex] = operand.getAsOpaquePointer();
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case GetResult0:
|
|
|
|
case GetResult1:
|
|
|
|
case GetResult2:
|
|
|
|
case GetResult3:
|
|
|
|
case GetResultN: {
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << "Executing GetResult"
|
|
|
|
<< (opCode == GetResultN ? Twine("N")
|
|
|
|
: Twine(opCode - GetResult0))
|
|
|
|
<< ":\n";
|
|
|
|
});
|
|
|
|
unsigned index =
|
|
|
|
opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
unsigned memIndex = read();
|
|
|
|
OpResult result =
|
|
|
|
index < op->getNumResults() ? op->getResult(index) : OpResult();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
|
|
|
|
<< " * Index: " << index << "\n"
|
|
|
|
<< " * Result: " << result << "\n\n");
|
|
|
|
memory[memIndex] = result.getAsOpaquePointer();
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case GetValueType: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
|
|
|
|
unsigned memIndex = read();
|
|
|
|
Value value = read<Value>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
|
|
|
|
<< " * Result: " << value.getType() << "\n\n");
|
|
|
|
memory[memIndex] = value.getType().getAsOpaquePointer();
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case IsNotNull: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
|
|
|
|
const void *value = read<const void *>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n");
|
|
|
|
selectJump(value != nullptr);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case RecordMatch: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
|
|
|
|
assert(matches &&
|
|
|
|
"expected matches to be provided when executing the matcher");
|
|
|
|
unsigned patternIndex = read();
|
|
|
|
PatternBenefit benefit = currentPatternBenefits[patternIndex];
|
|
|
|
const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
|
|
|
|
|
|
|
|
// If the benefit of the pattern is impossible, skip the processing of the
|
|
|
|
// rest of the pattern.
|
|
|
|
if (benefit.isImpossibleToMatch()) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n");
|
|
|
|
curCodeIt = dest;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a fused location containing the locations of each of the
|
|
|
|
// operations used in the match. This will be used as the location for
|
|
|
|
// created operations during the rewrite that don't already have an
|
|
|
|
// explicit location set.
|
|
|
|
unsigned numMatchLocs = read();
|
|
|
|
SmallVector<Location, 4> matchLocs;
|
|
|
|
matchLocs.reserve(numMatchLocs);
|
|
|
|
for (unsigned i = 0; i != numMatchLocs; ++i)
|
|
|
|
matchLocs.push_back(read<Operation *>()->getLoc());
|
|
|
|
Location matchLoc = rewriter.getFusedLoc(matchLocs);
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
|
|
|
|
<< " * Location: " << matchLoc << "\n\n");
|
|
|
|
matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
|
|
|
|
readList<const void *>(matches->back().values);
|
|
|
|
curCodeIt = dest;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case ReplaceOp: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
SmallVector<Value, 16> args;
|
|
|
|
readList<Value>(args);
|
|
|
|
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << " * Operation: " << *op << "\n"
|
|
|
|
<< " * Values: ";
|
|
|
|
llvm::interleaveComma(args, llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n\n";
|
|
|
|
});
|
|
|
|
rewriter.replaceOp(op, args);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case SwitchAttribute: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
|
|
|
|
Attribute value = read<Attribute>();
|
|
|
|
ArrayAttr cases = read<ArrayAttr>();
|
|
|
|
handleSwitch(value, cases);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case SwitchOperandCount: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
|
|
|
|
handleSwitch(op->getNumOperands(), cases);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case SwitchOperationName: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
|
|
|
|
OperationName value = read<Operation *>()->getName();
|
|
|
|
size_t caseCount = read();
|
|
|
|
|
|
|
|
// The operation names are stored in-line, so to print them out for
|
|
|
|
// debugging purposes we need to read the array before executing the
|
|
|
|
// switch so that we can display all of the possible values.
|
|
|
|
LLVM_DEBUG({
|
|
|
|
const ByteCodeField *prevCodeIt = curCodeIt;
|
|
|
|
llvm::dbgs() << " * Value: " << value << "\n"
|
|
|
|
<< " * Cases: ";
|
|
|
|
llvm::interleaveComma(
|
|
|
|
llvm::map_range(llvm::seq<size_t>(0, caseCount),
|
|
|
|
[&](size_t i) { return read<OperationName>(); }),
|
|
|
|
llvm::dbgs());
|
|
|
|
llvm::dbgs() << "\n\n";
|
|
|
|
curCodeIt = prevCodeIt;
|
|
|
|
});
|
|
|
|
|
|
|
|
// Try to find the switch value within any of the cases.
|
|
|
|
size_t jumpDest = 0;
|
|
|
|
for (size_t i = 0; i != caseCount; ++i) {
|
|
|
|
if (read<OperationName>() == value) {
|
|
|
|
curCodeIt += (caseCount - i - 1);
|
|
|
|
jumpDest = i + 1;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
selectJump(jumpDest);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case SwitchResultCount: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
|
|
|
|
Operation *op = read<Operation *>();
|
|
|
|
auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
|
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
|
|
|
|
handleSwitch(op->getNumResults(), cases);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case SwitchType: {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
|
|
|
|
Type value = read<Type>();
|
|
|
|
auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
|
|
|
|
handleSwitch(value, cases);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Run the pattern matcher on the given root operation, collecting the matched
|
|
|
|
/// patterns in `matches`.
|
|
|
|
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
|
|
|
|
SmallVectorImpl<MatchResult> &matches,
|
|
|
|
PDLByteCodeMutableState &state) const {
|
|
|
|
// The first memory slot is always the root operation.
|
|
|
|
state.memory[0] = op;
|
|
|
|
|
|
|
|
// The matcher function always starts at code address 0.
|
|
|
|
ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
|
|
|
|
matcherByteCode, state.currentPatternBenefits,
|
|
|
|
patterns, constraintFunctions, createFunctions,
|
|
|
|
rewriteFunctions);
|
|
|
|
executor.execute(rewriter, &matches);
|
|
|
|
|
|
|
|
// Order the found matches by benefit.
|
|
|
|
std::stable_sort(matches.begin(), matches.end(),
|
|
|
|
[](const MatchResult &lhs, const MatchResult &rhs) {
|
|
|
|
return lhs.benefit > rhs.benefit;
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Run the rewriter of the given pattern on the root operation `op`.
|
|
|
|
void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
|
|
|
|
PDLByteCodeMutableState &state) const {
|
|
|
|
// The arguments of the rewrite function are stored at the start of the
|
|
|
|
// memory buffer.
|
|
|
|
llvm::copy(match.values, state.memory.begin());
|
|
|
|
|
|
|
|
ByteCodeExecutor executor(
|
|
|
|
&rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
|
|
|
|
uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
|
|
|
|
constraintFunctions, createFunctions, rewriteFunctions);
|
|
|
|
executor.execute(rewriter, /*matches=*/nullptr, match.location);
|
|
|
|
}
|