Rewrite the DialectOpConversion patterns to inherit from RewritePattern instead of Pattern. This simplifies the infrastructure a bit by being able to reuse PatternRewriter and the RewritePatternMatcher, but also starts to lay the groundwork for a more generalized legalization framework that can operate on DialectOpConversions as well as normal RewritePatterns.

--

PiperOrigin-RevId: 248836492
This commit is contained in:
River Riddle 2019-05-17 22:21:13 -07:00 committed by Mehdi Amini
parent b5ecbb7fd6
commit 3de0c7696b
13 changed files with 404 additions and 531 deletions

View File

@ -28,7 +28,9 @@ class DialectConversion;
class DialectOpConversion;
class MLIRContext;
class Module;
class RewritePattern;
class Type;
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
namespace LLVM {
class LLVMType;
} // end namespace LLVM
@ -39,22 +41,18 @@ namespace linalg {
/// Keep all other types unmodified.
mlir::Type convertLinalgType(mlir::Type t);
/// Allocate the conversion patterns for RangeOp, ViewOp and SliceOp from the
/// Linalg dialect to the LLVM IR dialect. The converters are allocated in the
/// `allocator` using the provided `context`. The latter must have the LLVM IR
/// dialect registered.
/// This function can be used to apply multiple conversion patterns in the same
/// pass. It does not have to be called explicitly before the conversion.
llvm::DenseSet<mlir::DialectOpConversion *>
allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
/// Get the conversion patterns for RangeOp, ViewOp and SliceOp from the Linalg
/// dialect to the LLVM IR dialect. The LLVM IR dialect must be registered. This
/// function can be used to apply multiple conversion patterns in the same pass.
/// It does not have to be called explicitly before the conversion.
void getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context);
/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect.
/// The conversion is set up to convert types and function signatures using
/// `convertLinalgType` and obtains operation converters by calling `initer`.
std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
std::function<void(mlir::OwningRewritePatternList &, mlir::MLIRContext *)>
initer);
/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations

View File

@ -146,14 +146,8 @@ public:
explicit RangeOpConversion(MLIRContext *context)
: DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
if (isa<linalg::RangeOp>(op))
return matchSuccess();
return matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto rangeOp = cast<linalg::RangeOp>(op);
auto rangeDescriptorType =
linalg::convertLinalgType(rangeOp.getResult()->getType());
@ -169,7 +163,7 @@ public:
operands[1], makePositionAttr(rewriter, 1));
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
operands[2], makePositionAttr(rewriter, 2));
return {rangeDescriptor};
rewriter.replaceOp(op, rangeDescriptor);
}
};
@ -178,14 +172,8 @@ public:
explicit ViewOpConversion(MLIRContext *context)
: DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
if (isa<linalg::ViewOp>(op))
return matchSuccess();
return matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto viewOp = cast<linalg::ViewOp>(op);
auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
auto memrefType =
@ -301,7 +289,7 @@ public:
++i;
}
return {viewDescriptor};
rewriter.replaceOp(op, viewDescriptor);
}
};
@ -310,14 +298,8 @@ public:
explicit SliceOpConversion(MLIRContext *context)
: DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
if (isa<linalg::SliceOp>(op))
return matchSuccess();
return matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto sliceOp = cast<linalg::SliceOp>(op);
auto newViewDescriptorType =
linalg::convertLinalgType(sliceOp.getViewType());
@ -398,7 +380,7 @@ public:
stride, pos({3, i}));
}
return {newViewDescriptor};
rewriter.replaceOp(op, newViewDescriptor);
}
};
@ -409,41 +391,30 @@ public:
explicit DropConsumer(MLIRContext *context)
: DialectOpConversion("some_consumer", 1, context) {}
PatternMatchResult match(Operation *op) const override {
if (op->getName().getStringRef() == "some_consumer")
return matchSuccess();
return matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *, ArrayRef<Value *>,
FuncBuilder &) const override {
return {};
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {}
};
llvm::DenseSet<mlir::DialectOpConversion *>
linalg::allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context) {
return ConversionListBuilder<DropConsumer, RangeOpConversion,
SliceOpConversion,
ViewOpConversion>::build(allocator, context);
ConversionListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
ViewOpConversion>::build(patterns, context);
}
namespace {
// The conversion class from Linalg to LLVMIR.
class Lowering : public DialectConversion {
public:
explicit Lowering(std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
explicit Lowering(std::function<void(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context)>
conversions)
: setup(conversions) {}
protected:
// Initialize the list of converters.
llvm::DenseSet<DialectOpConversion *>
initConverters(MLIRContext *context) override {
converterStorage.Reset();
return setup(&converterStorage, context);
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *context) override {
setup(patterns, context);
}
// This gets called for block and region arguments, and attributes.
@ -475,19 +446,15 @@ protected:
}
private:
// Storage for individual converters.
llvm::BumpPtrAllocator converterStorage;
// Conversion setup.
std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
std::function<void(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context)>
setup;
};
} // end anonymous namespace
std::unique_ptr<mlir::DialectConversion> linalg::makeLinalgToLLVMLowering(
std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
std::function<void(mlir::OwningRewritePatternList &, mlir::MLIRContext *)>
initer) {
return llvm::make_unique<Lowering>(initer);
}
@ -502,7 +469,7 @@ void linalg::convertToLLVM(mlir::Module &module) {
// Convert Linalg ops to the LLVM IR dialect using the converter defined
// above.
auto r = Lowering(allocateDescriptorConverters).convert(&module);
auto r = Lowering(getDescriptorConverters).convert(&module);
(void)r;
assert(succeeded(r) && "conversion failed");

View File

@ -104,15 +104,15 @@ public:
// an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
using Base::Base;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
Value *viewDescriptor = operands[0];
ArrayRef<Value *> indices = operands.drop_front();
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
Value *element = intrinsics::load(elementType, ptr);
return {element};
rewriter.replaceOp(op, {element});
}
};
@ -120,15 +120,14 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
using Base::Base;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
Value *viewDescriptor = operands[1];
Value *data = operands[0];
ArrayRef<Value *> indices = operands.drop_front(2);
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
intrinsics::store(data, ptr);
return {};
}
};
@ -136,15 +135,11 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
// Helper function that allocates the descriptor converters and adds load/store
// coverters to the list.
static llvm::DenseSet<mlir::DialectOpConversion *>
allocateConversions(llvm::BumpPtrAllocator *allocator,
mlir::MLIRContext *context) {
auto conversions = linalg::allocateDescriptorConverters(allocator, context);
auto additional =
ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build(
allocator, context);
conversions.insert(additional.begin(), additional.end());
return conversions;
static void getConversions(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context) {
linalg::getDescriptorConverters(patterns, context);
ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
context);
}
void linalg::convertLinalg3ToLLVM(Module &module) {
@ -155,7 +150,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
(void)rr;
assert(succeeded(rr) && "affine loop lowering failed");
auto lowering = makeLinalgToLLVMLowering(allocateConversions);
auto lowering = makeLinalgToLLVMLowering(getConversions);
auto r = lowering->convert(&module);
(void)r;
assert(succeeded(r) && "conversion failed");

View File

@ -86,8 +86,8 @@ public:
explicit MulOpConversion(MLIRContext *context)
: DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
using namespace edsc;
using intrinsics::constant_index;
using linalg::intrinsics::range;
@ -115,7 +115,7 @@ public:
auto rhsView = view(rhs, {r1, r2});
auto resultView = view(result, {r0, r2});
rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
return {typeCast(rewriter, result, mul.getType())};
rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())});
}
};
@ -123,20 +123,16 @@ public:
class EarlyLowering : public DialectConversion {
protected:
// Initialize the list of converters.
llvm::DenseSet<DialectOpConversion *>
initConverters(MLIRContext *context) override {
return ConversionListBuilder<MulOpConversion>::build(&allocator, context);
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *context) override {
ConversionListBuilder<MulOpConversion>::build(patterns, context);
}
private:
llvm::BumpPtrAllocator allocator;
};
/// This is lowering to Linalg the parts that are computationally intensive
/// (like matmul for example...) while keeping the rest of the code in the Toy
/// dialect.
struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
void runOnModule() override {
if (failed(EarlyLowering().convert(&getModule()))) {
getModule().getContext()->emitError(

View File

@ -91,8 +91,8 @@ public:
/// the rewritten operands for `op` in the new function.
/// The results created by the new IR with the builder are returned, and their
/// number must match the number of result of `op`.
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto add = cast<toy::AddOp>(op);
auto loc = add.getLoc();
// Create a `toy.alloc` operation to allocate the output buffer for this op.
@ -119,7 +119,7 @@ public:
// Return the newly allocated buffer, with a type.cast to preserve the
// consumers.
return {typeCast(rewriter, result, add.getType())};
rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())});
}
};
@ -130,8 +130,8 @@ public:
explicit PrintOpConversion(MLIRContext *context)
: DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
Function *printfFunc = getPrintf(*op->getFunction()->getModule());
@ -175,7 +175,6 @@ public:
});
// clang-format on
}
return {};
}
private:
@ -232,8 +231,8 @@ public:
explicit ConstantOpConversion(MLIRContext *context)
: DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
auto loc = cstOp.getLoc();
auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
@ -265,7 +264,7 @@ public:
constant_float(value, f64Ty);
}
}
return {result};
rewriter.replaceOp(op, result);
}
};
@ -275,8 +274,8 @@ public:
explicit TransposeOpConversion(MLIRContext *context)
: DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto transpose = cast<toy::TransposeOp>(op);
auto loc = transpose.getLoc();
Value *result = memRefTypeCast(
@ -297,7 +296,7 @@ public:
});
// clang-format on
return {typeCast(rewriter, result, transpose.getType())};
rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())});
}
};
@ -307,8 +306,8 @@ public:
explicit ReturnOpConversion(MLIRContext *context)
: DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto retOp = cast<toy::ReturnOp>(op);
using namespace edsc;
auto loc = retOp.getLoc();
@ -317,7 +316,6 @@ public:
rewriter.create<ReturnOp>(loc, operands[0]);
else
rewriter.create<ReturnOp>(loc);
return {};
}
};
@ -326,31 +324,25 @@ public:
class LateLowering : public DialectConversion {
protected:
/// Initialize the list of converters.
llvm::DenseSet<DialectOpConversion *>
initConverters(MLIRContext *context) override {
return ConversionListBuilder<AddOpConversion, PrintOpConversion,
ConstantOpConversion, TransposeOpConversion,
ReturnOpConversion>::build(&allocator,
context);
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *context) override {
ConversionListBuilder<AddOpConversion, PrintOpConversion,
ConstantOpConversion, TransposeOpConversion,
ReturnOpConversion>::build(patterns, context);
}
/// Convert a Toy type, this gets called for block and region arguments, and
/// attributes.
Type convertType(Type t) override {
if (auto array = t.cast<toy::ToyArrayType>()) {
if (auto array = t.cast<toy::ToyArrayType>())
return array.toMemref();
}
return t;
}
private:
llvm::BumpPtrAllocator allocator;
};
/// This is lowering to Linalg the parts that can be (matmul and add on arrays)
/// and is targeting LLVM otherwise.
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
void runOnModule() override {
// Perform Toy specific lowering
if (failed(LateLowering().convert(&getModule()))) {

View File

@ -248,8 +248,8 @@ public:
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead = {});
virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead = {});
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
@ -326,14 +326,12 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
///
class RewritePatternMatcher {
public:
/// Create a RewritePatternMatcher with the specified set of patterns and
/// rewriter.
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns,
PatternRewriter &rewriter);
/// Create a RewritePatternMatcher with the specified set of patterns.
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
/// Try to match the given operation to a pattern and rewrite it. Return
/// true if any pattern matches.
bool matchAndRewrite(Operation *op);
bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
private:
RewritePatternMatcher(const RewritePatternMatcher &) = delete;
@ -342,9 +340,6 @@ private:
/// The group of patterns that are matched for optimization through this
/// matcher.
OwningRewritePatternList patterns;
/// The rewriter used when applying matched patterns.
PatternRewriter &rewriter;
};
/// Rewrite the specified function by repeatedly applying the highest benefit

View File

@ -56,18 +56,15 @@ public:
llvm::LLVMContext &getLLVMContext();
protected:
/// Create a set of converters that live in the pass object by passing them a
/// reference to the LLVM IR dialect. Store the module associated with the
/// dialect for further type conversion.
llvm::DenseSet<DialectOpConversion *>
initConverters(MLIRContext *mlirContext) override final;
/// Add a set of converters to the given pattern list. Store the module
/// associated with the dialect for further type conversion.
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *mlirContext) override final;
/// Derived classes can override this function to initialize custom converters
/// in addition to the existing converters from Standard operations. It will
/// be called after the `module` and `llvmDialect` have been made available.
virtual llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() {
return {};
};
virtual void initAdditionalConverters(OwningRewritePatternList &patterns) {}
/// Derived classes can override this function to convert custom types. It
/// will be called by convertType if the default conversion from standard and

View File

@ -42,63 +42,65 @@ class FunctionConversion;
}
/// Base class for the dialect op conversion patterns. Specific conversions
/// must derive this class and implement least one of `rewrite` and
/// `rewriteTerminator`. Optionally they can also override
/// `PatternMatch match(Operation *)` to match more specific operations than the
/// `rootName` provided in the constructor.
//
// TODO(zinenko): this should eventually converge with RewritePattern. So far,
// rewritePattern is missing support for operations with successors as well as
// an ability to accept new operands instead of reusing those of the existing
// operation.
class DialectOpConversion : public Pattern {
/// must derive this class and implement least one `rewrite` method. Optionally
/// they can also override `PatternMatch match(Operation *)` to match more
/// specific operations than the `rootName` provided in the constructor.
/// NOTE: These conversion patterns can only be used with the DialectConversion
/// class.
class DialectOpConversion : public RewritePattern {
public:
/// Construct an DialectOpConversion. `rootName` must correspond to the
/// canonical name of the first operation matched by the pattern.
DialectOpConversion(StringRef rootName, PatternBenefit benefit,
MLIRContext *ctx)
: Pattern(rootName, benefit, ctx) {}
: RewritePattern(rootName, benefit, ctx) {}
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// Hook for derived classes to implement matching. Dialect conversion
/// generally unconditionally match the root operation, so default to success
/// here.
virtual PatternMatchResult match(Operation *op) const override {
return matchSuccess();
}
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `operands` is a list of rewritten values
/// that are passed to this operation, `rewriter` can be used to emit the new
/// operations. This function returns the values produced by the newly
/// created operation(s). These values will be used instead of those produced
/// by the original operation. This function must be reimplemented if the
/// DialectOpConversion ever needs to replace an operation that does not have
/// successors. This function should not fail. If some specific cases of the
/// operation are not supported, these cases should not be matched.
virtual SmallVector<Value *, 4> rewrite(Operation *op,
ArrayRef<Value *> operands,
FuncBuilder &rewriter) const {
llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?");
};
/// operations. This function must be reimplemented if the DialectOpConversion
/// ever needs to replace an operation that does not have successors. This
/// function should not fail. If some specific cases of the operation are not
/// supported, these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `properOperands` is a list of rewritten
/// values that are passed to the operation itself, `destinations` is a list
/// of (potentially rewritten) successor blocks, `operands` is a list of lists
/// of rewritten values passed to each of the successors, co-indexed with
/// `destinations`, `rewriter` can be used to emit the new operations. Since
/// terminators never produce results (which could not be used anyway), this
/// function does not return anything. It must be reimplemented if the
/// DialectOpConversion ever needs to replace a terminator operation that has
/// successors. This function should not fail the pass. If some specific
/// cases of the operation are not supported, these cases should not be
/// matched.
virtual void rewriteTerminator(Operation *op,
ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
FuncBuilder &rewriter) const {
llvm_unreachable("unimplemented rewriteTerminator, did you mean rewrite?");
/// `destinations`, `rewriter` can be used to emit the new operations. It must
/// be reimplemented if the DialectOpConversion ever needs to replace a
/// terminator operation that has successors. This function should not fail
/// the pass. If some specific cases of the operation are not supported,
/// these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
PatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}
/// Provide a default implementation for matching: most DialectOpConversion
/// implementations are unconditionally matching.
PatternMatchResult match(Operation *op) const override {
return matchSuccess();
}
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
void rewrite(Operation *op, PatternRewriter &rewriter) const final;
private:
using RewritePattern::matchAndRewrite;
using RewritePattern::rewrite;
};
// Helper class to create a list of dialect conversion patterns given a list of
@ -106,27 +108,22 @@ public:
// conversion constructors.
template <typename Arg, typename... Args> struct ConversionListBuilder {
template <typename... ConstructorArgs>
static llvm::DenseSet<DialectOpConversion *>
build(llvm::BumpPtrAllocator *allocator,
ConstructorArgs &&... constructorArgs) {
auto sub = ConversionListBuilder<Args...>::build(
allocator, std::forward<ConstructorArgs>(constructorArgs)...);
auto *ptr = allocator->Allocate<Arg>();
new (ptr) Arg(std::forward<ConstructorArgs>(constructorArgs)...);
sub.insert(ptr);
return sub;
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
ConversionListBuilder<Args...>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
ConversionListBuilder<Arg>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
}
};
// Template specialization to stop recursion.
template <typename Arg> struct ConversionListBuilder<Arg> {
template <typename... ConstructorArgs>
static llvm::DenseSet<DialectOpConversion *>
build(llvm::BumpPtrAllocator *allocator,
ConstructorArgs &&... constructorArgs) {
auto *ptr = allocator->Allocate<Arg>();
new (ptr) Arg(std::forward<ConstructorArgs>(constructorArgs)...);
return {ptr};
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
patterns.emplace_back(llvm::make_unique<Arg>(
std::forward<ConstructorArgs>(constructorArgs)...));
}
};
@ -138,19 +135,17 @@ template <typename Arg> struct ConversionListBuilder<Arg> {
/// current MLIR context.
/// 2. For each function in the module do the following.
// a. Create a new function with the same name and convert its signature
// using `convertType`.
// using `convertType`.
// b. For each block in the function, create a block in the function with
// its arguments converted using `convertType`.
// its arguments converted using `convertType`.
// c. Traverse blocks in DFS-preorder of successors starting from the entry
// block (if any), and convert individual operations as follows. Pattern
// match against the list of conversions. On the first match, call
// `rewriteTerminator` for terminator operations with successors and
// `rewrite` for other operations, and advance to the next iteration. If no
// match is found, replicate the operation as is. Note that if two patterns
// match the same operation, it is undefined which of them will be applied.
// block (if any), and convert individual operations as follows. Pattern
// match against the list of conversions. On the first match, call
// `rewrite` for the operations, and advance to the next iteration. If no
// match is found, replicate the operation as is.
/// 3. Update all attributes of function type to point to the new functions.
/// 4. Replace old functions with new functions in the module.
/// If any error happend during the conversion, the pass fails as soon as
/// If any error happened during the conversion, the pass fails as soon as
/// possible.
///
/// If the conversion fails, the module is not modified.
@ -167,9 +162,10 @@ public:
protected:
/// Derived classes must implement this hook to produce a set of conversion
/// patterns to apply. They may use `mlirContext` to obtain registered
/// dialects or operations. This will be called in the beginning of the pass.
virtual llvm::DenseSet<DialectOpConversion *>
initConverters(MLIRContext *mlirContext) = 0;
/// dialects or operations. This will be called in the beginning of the
/// conversion.
virtual void initConverters(OwningRewritePatternList &patterns,
MLIRContext *mlirContext) = 0;
/// Derived classes must reimplement this hook if they need to convert
/// block or function argument types or function result types. If the target

View File

@ -123,8 +123,8 @@ void PatternRewriter::updatedRootInPlace(
//===----------------------------------------------------------------------===//
RewritePatternMatcher::RewritePatternMatcher(
OwningRewritePatternList &&patterns, PatternRewriter &rewriter)
: patterns(std::move(patterns)), rewriter(rewriter) {
OwningRewritePatternList &&patterns)
: patterns(std::move(patterns)) {
// Sort the patterns by benefit to simplify the matching logic.
std::stable_sort(this->patterns.begin(), this->patterns.end(),
[](const std::unique_ptr<RewritePattern> &l,
@ -134,7 +134,8 @@ RewritePatternMatcher::RewritePatternMatcher(
}
/// Try to match the given operation to a pattern and rewrite it.
bool RewritePatternMatcher::matchAndRewrite(Operation *op) {
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) {
for (auto &pattern : patterns) {
// Ignore patterns that are for the wrong root or are impossible to match.
if (pattern->getRootKind() != op->getName() ||

View File

@ -212,11 +212,8 @@ public:
lowering_),
dialect(dialect_) {}
// Match by type.
PatternMatchResult match(Operation *op) const override {
if (isa<SourceOp>(op))
return this->matchSuccess();
return this->matchFailure();
return this->matchSuccess();
}
// Get the LLVM IR dialect.
@ -241,7 +238,7 @@ public:
}
// Create an LLVM IR pseudo-operation defining the given index constant.
Value *createIndexConstant(FuncBuilder &builder, Location loc,
Value *createIndexConstant(PatternRewriter &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
@ -249,7 +246,7 @@ public:
// Get the array attribute named "position" containing the given list of
// integers as integer attribute elements.
static ArrayAttr getIntegerArrayAttr(FuncBuilder &builder,
static ArrayAttr getIntegerArrayAttr(PatternRewriter &builder,
ArrayRef<int64_t> values) {
SmallVector<Attribute, 4> attrs;
attrs.reserve(values.size());
@ -259,7 +256,7 @@ public:
}
// Extract raw data pointer value from a value representing a memref.
static Value *extractMemRefElementPtr(FuncBuilder &builder, Location loc,
static Value *extractMemRefElementPtr(PatternRewriter &builder, Location loc,
Value *convertedMemRefValue,
Type elementTypePtr,
bool hasStaticShape) {
@ -297,8 +294,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
Type packedType;
@ -314,9 +311,9 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
return {};
return;
if (numResults == 1)
return {newOp.getOperation()->getResult(0)};
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0));
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
@ -328,7 +325,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
op->getLoc(), type, newOp.getOperation()->getResult(0),
this->getIntegerArrayAttr(rewriter, i)));
}
return results;
rewriter.replaceOp(op, results);
}
};
@ -417,15 +414,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
if (!LLVMLegalizationPattern<AllocOp>::match(op))
return matchFailure();
auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
MemRefType type = cast<AllocOp>(op).getType();
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
@ -449,7 +443,6 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
op->getLoc(), getIndexType(),
ArrayRef<Value *>{cumulativeSize, sizes[i]});
// Compute the total amount of bytes to allocate.
auto elementType = type.getElementType();
assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
@ -493,9 +486,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
ArrayRef<Value *>(allocated));
// Deal with static memrefs
if (numOperands == 0) {
return {allocated};
}
if (numOperands == 0)
return rewriter.replaceOp(op, allocated);
// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
@ -515,7 +507,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
}
// Return the final value of the descriptor.
return {memRefDescriptor};
rewriter.replaceOp(op, memRefDescriptor);
}
};
@ -525,8 +517,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
// Insert the `free` declaration if it is not already present.
@ -552,7 +544,6 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
op->getLoc(), getVoidPtrType(), bufferPtr);
rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(),
rewriter.getFunctionAttr(freeFunc), casted);
return {};
}
};
@ -560,8 +551,6 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
return matchFailure();
auto memRefCastOp = cast<MemRefCastOp>(op);
MemRefType sourceType =
memRefCastOp.getOperand()->getType().cast<MemRefType>();
@ -572,8 +561,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
: matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
auto targetType = memRefCastOp.getType();
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
@ -584,9 +573,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
elementTypePtr, sourceType.hasStaticShape());
// Account for static memrefs as target types
if (targetType.hasStaticShape()) {
return {buffer};
}
if (targetType.hasStaticShape())
return rewriter.replaceOp(op, buffer);
// Create the new MemRef descriptor.
auto structType = lowering.convertType(targetType);
@ -629,7 +617,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
"target dynamic dimensions were not set up");
return {newDescriptor};
rewriter.replaceOp(op, newDescriptor);
}
};
@ -639,20 +627,17 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
if (!LLVMLegalizationPattern<DimOp>::match(op))
return this->matchFailure();
auto dimOp = cast<DimOp>(op);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "expected exactly one operand");
auto dimOp = cast<DimOp>(op);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
SmallVector<Value *, 4> results;
auto shape = type.getShape();
uint64_t index = dimOp.getIndex();
// Extract dynamic size from the memref descriptor and define static size
@ -666,14 +651,13 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
if (shape[i] == -1)
++position;
}
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(), operands[0],
getIntegerArrayAttr(rewriter, position)));
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), operands[0],
getIntegerArrayAttr(rewriter, position));
} else {
results.push_back(
createIndexConstant(rewriter, op->getLoc(), shape[index]));
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
}
return results;
}
};
@ -686,10 +670,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
using Base = LoadStoreOpLowering<Derived>;
PatternMatchResult match(Operation *op) const override {
if (!LLVMLegalizationPattern<Derived>::match(op))
return this->matchFailure();
auto loadOp = cast<Derived>(op);
MemRefType type = loadOp.getMemRefType();
MemRefType type = cast<Derived>(op).getMemRefType();
return isSupportedMemRefType(type) ? this->matchSuccess()
: this->matchFailure();
}
@ -702,7 +683,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
Value *linearizeSubscripts(FuncBuilder &builder, Location loc,
Value *linearizeSubscripts(PatternRewriter &builder, Location loc,
ArrayRef<Value *> indices,
ArrayRef<Value *> allocSizes) const {
assert(indices.size() == allocSizes.size() &&
@ -727,7 +708,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// indices.
Value *getElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *memRefDescriptor,
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
ArrayRef<Value *> indices,
PatternRewriter &rewriter) const {
// Get the list of MemRef sizes. Static sizes are defined as constants.
// Dynamic sizes are extracted from the MemRef descriptor, where they start
// from the position 1 (the buffer is at position 0).
@ -763,7 +745,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
Value *getRawElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *rawDataPtr,
ArrayRef<Value *> indices,
FuncBuilder &rewriter) const {
PatternRewriter &rewriter) const {
if (shape.empty())
return rawDataPtr;
@ -779,16 +761,15 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
}
Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
ArrayRef<Value *> indices, FuncBuilder &rewriter,
ArrayRef<Value *> indices, PatternRewriter &rewriter,
llvm::Module &module) const {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
auto shape = type.getShape();
if (type.hasStaticShape()) {
// NB: If memref was statically-shaped, dataPtr is pointer to raw data.
return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
} else {
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
}
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
}
};
@ -797,8 +778,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
auto type = loadOp.getMemRefType();
@ -806,10 +787,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
operands.drop_front(), rewriter, getModule());
auto elementType = lowering.convertType(type.getElementType());
SmallVector<Value *, 4> results;
results.push_back(rewriter.create<LLVM::LoadOp>(
op->getLoc(), elementType, ArrayRef<Value *>{dataPtr}));
return results;
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
ArrayRef<Value *>{dataPtr});
}
};
@ -818,16 +797,13 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto storeOp = cast<StoreOp>(op);
auto type = storeOp.getMemRefType();
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
operands.drop_front(2), rewriter, getModule());
rewriter.create<LLVM::StoreOp>(op->getLoc(), operands[0], dataPtr);
return {};
}
};
@ -838,10 +814,10 @@ struct OneToOneLLVMTerminatorLowering
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
void rewriteTerminator(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
PatternRewriter &rewriter) const override {
rewriter.create<TargetOp>(op->getLoc(), properOperands, destinations,
operands, op->getAttrs());
}
@ -856,8 +832,8 @@ struct OneToOneLLVMTerminatorLowering
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
// If ReturnOp has 0 or 1 operand, create it and return immediately.
@ -865,14 +841,14 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
rewriter.create<LLVM::ReturnOp>(
op->getLoc(), llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
return {};
return;
}
if (numArguments == 1) {
rewriter.create<LLVM::ReturnOp>(
op->getLoc(), llvm::ArrayRef<Value *>(operands.front()),
llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
op->getAttrs());
return {};
return;
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
@ -888,7 +864,6 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
rewriter.create<LLVM::ReturnOp>(
op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
return {};
}
};
@ -954,20 +929,19 @@ void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
// Create a set of converters that live in the pass object by passing them a
// reference to the LLVM IR dialect. Store the module associated with the
// dialect for further type conversion.
llvm::DenseSet<DialectOpConversion *>
LLVMLowering::initConverters(MLIRContext *mlirContext) {
converterStorage.Reset();
void LLVMLowering::initConverters(OwningRewritePatternList &patterns,
MLIRContext *mlirContext) {
llvmDialect = mlirContext->getRegisteredDialect<LLVM::LLVMDialect>();
if (!llvmDialect) {
mlirContext->emitError(UnknownLoc::get(mlirContext),
"LLVM IR dialect is not registered");
return {};
return;
}
module = &llvmDialect->getLLVMModule();
// FIXME: this should be tablegen'ed
auto converters = ConversionListBuilder<
ConversionListBuilder<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
@ -975,11 +949,8 @@ LLVMLowering::initConverters(MLIRContext *mlirContext) {
LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect,
*this);
auto extraConverters = initAdditionalConverters();
converters.insert(extraConverters.begin(), extraConverters.end());
return converters;
SubIOpLowering, XOrOpLowering>::build(patterns, *llvmDialect, *this);
initAdditionalConverters(patterns);
}
// Convert types using the stored LLVM IR module.

View File

@ -147,7 +147,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
// Create an array attribute containing integer attributes with values provided
// in `position`.
static ArrayAttr positionAttr(FuncBuilder &builder, ArrayRef<int> position) {
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
SmallVector<Attribute, 4> attrs;
attrs.reserve(position.size());
for (auto p : position)
@ -162,8 +162,8 @@ public:
LLVMLowering &lowering_)
: LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy = LLVM::LLVMType::get(
op->getContext(),
@ -207,7 +207,7 @@ public:
positionAttr(rewriter, 0));
desc = insertvalue(bufferDescriptorType, desc, size,
positionAttr(rewriter, 1));
return {desc};
rewriter.replaceOp(op, desc);
}
};
@ -219,8 +219,8 @@ public:
: LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto voidPtrTy = LLVM::LLVMType::get(
op->getContext(),
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
@ -243,8 +243,6 @@ public:
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
positionAttr(rewriter, 0)));
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
return {};
}
};
@ -254,11 +252,12 @@ public:
BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
: LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto int64Ty = lowering.convertType(operands[0]->getType());
edsc::ScopedContext context(rewriter, op->getLoc());
return {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))};
rewriter.replaceOp(
op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
}
};
@ -268,14 +267,16 @@ public:
explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
: LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto dimOp = cast<linalg::DimOp>(op);
auto indexTy = lowering.convertType(rewriter.getIndexType());
edsc::ScopedContext context(rewriter, op->getLoc());
return {extractvalue(
indexTy, operands[0],
positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
rewriter.replaceOp(
op,
{extractvalue(
indexTy, operands[0],
positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
}
};
@ -293,7 +294,8 @@ public:
// descriptor to emit IR iteratively computing the actual offset, followed by
// a getelementptr. This must be called under an edsc::ScopedContext.
Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
ArrayRef<Value *> indices,
PatternRewriter &rewriter) const {
auto loadOp = cast<Op>(op);
auto elementTy = rewriter.getType<LLVMType>(
getPtrToElementType(loadOp.getViewType(), lowering));
@ -320,15 +322,14 @@ public:
// an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
using Base::Base;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementTy = lowering.convertType(*op->getResultTypes().begin());
Value *viewDescriptor = operands[0];
ArrayRef<Value *> indices = operands.drop_front();
auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
Value *element = llvm_load(elementTy, ptr);
return {element};
rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
}
};
@ -338,8 +339,8 @@ public:
explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
: LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
convertLinalgType(rangeOp.getResult()->getType(), lowering);
@ -354,8 +355,7 @@ public:
positionAttr(rewriter, 1));
desc = insertvalue(rangeDescriptorTy, desc, operands[2],
positionAttr(rewriter, 2));
return {desc};
rewriter.replaceOp(op, desc);
}
};
@ -367,8 +367,8 @@ public:
: LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto rangeIntersectOp = cast<RangeIntersectOp>(op);
auto rangeDescriptorTy =
convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
@ -400,8 +400,7 @@ public:
// TODO(ntv): this assumes both steps are one for now. Enforce and extend.
desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2),
positionAttr(rewriter, 2));
return {desc};
rewriter.replaceOp(op, desc);
}
};
@ -410,8 +409,8 @@ public:
explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_)
: LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto sliceOp = cast<SliceOp>(op);
auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
@ -483,7 +482,7 @@ public:
++i;
}
return {desc};
rewriter.replaceOp(op, desc);
}
};
@ -491,15 +490,14 @@ public:
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
using Base::Base;
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
Value *data = operands[0];
Value *viewDescriptor = operands[1];
ArrayRef<Value *> indices = operands.drop_front(2);
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
llvm_store(data, ptr);
return {};
}
};
@ -508,8 +506,8 @@ public:
explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
: LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto viewOp = cast<ViewOp>(op);
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
auto elementTy = rewriter.getType<LLVMType>(
@ -556,7 +554,7 @@ public:
runningStride = mul(runningStride, max);
}
return {desc};
rewriter.replaceOp(op, desc);
}
};
@ -568,18 +566,20 @@ public:
static StringRef libraryFunctionName() { return "linalg_dot"; }
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto *f =
op->getFunction()->getModule()->getNamedFunction(libraryFunctionName());
if (!f)
if (!f) {
op->emitError("Could not find function: " + libraryFunctionName() +
"in lowering to LLVM ");
return;
}
auto fAttr = rewriter.getFunctionAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
rewriter.create<LLVM::CallOp>(op->getLoc(), operands,
ArrayRef<NamedAttribute>{named});
return {};
}
};
@ -587,14 +587,13 @@ namespace {
// The conversion class from Linalg to LLVMIR.
class Lowering : public LLVMLowering {
protected:
llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
void initAdditionalConverters(OwningRewritePatternList &patterns) override {
return ConversionListBuilder<
BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion, DotOpConversion,
LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion,
SliceOpConversion, StoreOpConversion,
ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
*this);
ViewOpConversion>::build(patterns, llvmDialect->getContext(), *this);
}
Type convertAdditionalType(Type t) override {

View File

@ -28,6 +28,84 @@
using namespace mlir;
namespace {
/// This class implements a pattern rewriter for DialectOpConversion patterns.
/// It automatically performs remapping of replaced operation values.
struct DialectConversionRewriter final : public PatternRewriter {
DialectConversionRewriter(Function *fn) : PatternRewriter(fn) {}
~DialectConversionRewriter() = default;
// Implement the hook for replacing an operation with new values.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) override {
assert(newValues.size() == op->getNumResults());
for (unsigned i = 0, e = newValues.size(); i < e; ++i)
mapping.map(op->getResult(i), newValues[i]);
}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
Operation *createOperation(const OperationState &state) override {
return FuncBuilder::createOperation(state);
}
void lookupValues(Operation::operand_range operands,
SmallVectorImpl<Value *> &remapped) {
remapped.reserve(llvm::size(operands));
for (Value *operand : operands) {
Value *value = mapping.lookupOrNull(operand);
assert(value && "converting op before ops defining its operands");
remapped.push_back(value);
}
}
// Mapping between values(blocks) in the original function and in the new
// function.
BlockAndValueMapping mapping;
};
} // end anonymous namespace
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
void DialectOpConversion::rewrite(Operation *op,
PatternRewriter &rewriter) const {
SmallVector<Value *, 4> operands;
auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
dialectRewriter.lookupValues(op->getOperands(), operands);
// If this operation has no successors, invoke the rewrite directly.
if (op->getNumSuccessors() == 0)
return rewrite(op, operands, rewriter);
// Otherwise, we need to remap the successors.
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());
SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
// Lookup the successor.
auto *successor = dialectRewriter.mapping.lookupOrNull(op->getSuccessor(i));
assert(successor && "block was not remapped");
destinations.push_back(successor);
// Lookup the successors operands.
unsigned n = op->getNumSuccessorOperands(i);
operandsPerDestination.push_back(
llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
seen += n;
}
// Rewrite the operation.
rewrite(op,
llvm::makeArrayRef(operands.data(),
operands.data() + firstSuccessorOperand),
destinations, operandsPerDestination, rewriter);
}
namespace mlir {
namespace impl {
// Implementation detail class of the DialectConversion pass. Performs
@ -36,24 +114,14 @@ namespace impl {
// old functions with the new ones in the module.
class FunctionConversion {
public:
// Entry point. Uses hooks defined in `conversion` to obtain the list of
// conversion patterns and to convert function and block argument types.
// Converts the `module` in-place by replacing all existing functions with the
// converted ones.
static LogicalResult convert(DialectConversion *conversion, Module *module);
private:
// Constructs a FunctionConversion by storing the hooks.
explicit FunctionConversion(DialectConversion *conversion)
: dialectConversion(conversion) {}
explicit FunctionConversion(DialectConversion *conversion, Function *func,
RewritePatternMatcher &matcher)
: dialectConversion(conversion), rewriter(func), matcher(matcher) {}
// Utility that looks up a list of value in the value remapping table. Returns
// an empty vector if one of the values is not mapped yet.
SmallVector<Value *, 4> lookupValues(Operation::operand_range operands);
// Converts the given function to the dialect using hooks defined in
// Converts the current function to the dialect using hooks defined in
// `dialectConversion`. Returns the converted function or `nullptr` on error.
Function *convertFunction(Function *f);
Function *convertFunction();
// Converts the given region starting from the entry block and following the
// block successors. Returns the converted region or `nullptr` on error.
@ -61,19 +129,6 @@ private:
std::unique_ptr<Region> convertRegion(MLIRContext *context, Region *region,
RegionParent *parent);
// Converts an operation with successors. Extracts the converted operands
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
// passes them to `converter->rewriteTerminator` function defined in the
// pattern, together with `builder`.
LogicalResult convertOpWithSuccessors(DialectOpConversion *converter,
Operation *op, FuncBuilder &builder);
// Converts an operation without successors. Extracts the converted operands
// from `valueRemapping` and passes them to the `converter->rewrite` function
// defined in the pattern, together with `builder`.
LogicalResult convertOp(DialectOpConversion *converter, Operation *op,
FuncBuilder &builder);
// Converts a block by traversing its operations sequentially, looking for
// the first pattern match and dispatching the operation conversion to
// either `convertOp` or `convertOpWithSuccessors` depending on the presence
@ -81,128 +136,40 @@ private:
//
// After converting operations, traverses the successor blocks unless they
// have been visited already as indicated in `visitedBlocks`.
LogicalResult convertBlock(Block *block, FuncBuilder &builder,
LogicalResult convertBlock(Block *block,
llvm::DenseSet<Block *> &visitedBlocks);
// Converts the module as follows.
// 1. Call `convertFunction` on each function of the module and collect the
// mapping between old and new functions.
// 2. Remap all function attributes in the new functions to point to the new
// functions instead of the old ones.
// 3. Replace old functions with the new in the module.
LogicalResult run(Module *m);
// Pointer to a specific dialect pass.
DialectConversion *dialectConversion;
// Set of known conversion patterns.
llvm::DenseSet<DialectOpConversion *> conversions;
/// The writer used when rewriting operations.
DialectConversionRewriter rewriter;
// Mapping between values(blocks) in the original function and in the new
// function.
BlockAndValueMapping mapping;
/// The matcher use when converting operations.
RewritePatternMatcher &matcher;
};
} // end namespace impl
} // end namespace mlir
SmallVector<Value *, 4>
impl::FunctionConversion::lookupValues(Operation::operand_range operands) {
SmallVector<Value *, 4> remapped;
remapped.reserve(llvm::size(operands));
for (Value *operand : operands) {
Value *value = mapping.lookupOrNull(operand);
if (!value)
return {};
remapped.push_back(value);
}
return remapped;
}
LogicalResult impl::FunctionConversion::convertOpWithSuccessors(
DialectOpConversion *converter, Operation *op, FuncBuilder &builder) {
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());
SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
assert((!operands.empty() || op->getNumOperands() == 0) &&
"converting op before ops defining its operands");
SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
unsigned numSuccessorOperands = 0;
for (unsigned i = 0, e = op->getNumSuccessors(); i < e; ++i)
numSuccessorOperands += op->getNumSuccessorOperands(i);
unsigned seen = 0;
unsigned firstSuccessorOperand = op->getNumOperands() - numSuccessorOperands;
for (unsigned i = 0, e = op->getNumSuccessors(); i < e; ++i) {
Block *successor = mapping.lookupOrNull(op->getSuccessor(i));
assert(successor && "block was not remapped");
destinations.push_back(successor);
unsigned n = op->getNumSuccessorOperands(i);
operandsPerDestination.push_back(
llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
seen += n;
}
converter->rewriteTerminator(
op,
llvm::makeArrayRef(operands.data(),
operands.data() + firstSuccessorOperand),
destinations, operandsPerDestination, builder);
return success();
}
LogicalResult
impl::FunctionConversion::convertOp(DialectOpConversion *converter,
Operation *op, FuncBuilder &builder) {
auto operands = lookupValues(op->getOperands());
assert((!operands.empty() || op->getNumOperands() == 0) &&
"converting op before ops defining its operands");
auto results = converter->rewrite(op, operands, builder);
if (results.size() != op->getNumResults())
return op->emitError("rewriting produced a different number of results");
for (unsigned i = 0, e = results.size(); i < e; ++i)
mapping.map(op->getResult(i), results[i]);
return success();
}
LogicalResult
impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
impl::FunctionConversion::convertBlock(Block *block,
llvm::DenseSet<Block *> &visitedBlocks) {
// First, add the current block to the list of visited blocks.
visitedBlocks.insert(block);
// Setup the builder to the insert to the converted block.
builder.setInsertionPointToStart(mapping.lookupOrNull(block));
rewriter.setInsertionPointToStart(rewriter.mapping.lookupOrNull(block));
// Iterate over ops and convert them.
for (Operation &op : *block) {
// Find the first matching conversion and apply it.
bool converted = false;
for (auto *conversion : conversions) {
// Ignore patterns that are for the wrong root or are impossible to match.
if (conversion->getRootKind() != op.getName() ||
conversion->getBenefit().isImpossibleToMatch())
continue;
if (matcher.matchAndRewrite(&op, rewriter))
continue;
if (!conversion->match(&op))
continue;
if (op.getNumSuccessors() != 0) {
if (failed(convertOpWithSuccessors(conversion, &op, builder)))
return failure();
} else if (failed(convertOp(conversion, &op, builder))) {
return failure();
}
converted = true;
break;
}
// If there is no conversion provided for the op, clone the op and convert
// its regions, if any.
if (!converted) {
auto *newOp = builder.cloneWithoutRegions(op, mapping);
for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
newOp->getRegion(i).takeBody(*newRegion);
}
auto *newOp = rewriter.cloneWithoutRegions(op, rewriter.mapping);
for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
newOp->getRegion(i).takeBody(*newRegion);
}
}
@ -210,7 +177,7 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
for (Block *succ : block->getSuccessors()) {
if (visitedBlocks.count(succ) != 0)
continue;
if (failed(convertBlock(succ, builder, visitedBlocks)))
if (failed(convertBlock(succ, visitedBlocks)))
return failure();
}
return success();
@ -234,33 +201,31 @@ impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region,
for (Block &block : *region) {
auto *newBlock = new Block;
newRegion->push_back(newBlock);
mapping.map(&block, newBlock);
rewriter.mapping.map(&block, newBlock);
for (auto *arg : block.getArguments()) {
auto convertedType = dialectConversion->convertType(arg->getType());
if (!convertedType)
return emitError("could not convert block argument type");
newBlock->addArgument(convertedType);
mapping.map(arg, *newBlock->args_rbegin());
rewriter.mapping.map(arg, *newBlock->args_rbegin());
}
}
// Start a DFS-order traversal of the CFG to make sure defs are converted
// before uses in dominated blocks.
llvm::DenseSet<Block *> visitedBlocks;
FuncBuilder builder(&newRegion->front());
if (failed(convertBlock(&region->front(), builder, visitedBlocks)))
if (failed(convertBlock(&region->front(), visitedBlocks)))
return nullptr;
// If some blocks are not reachable through successor chains, they should have
// been removed by the DCE before this.
if (visitedBlocks.size() != std::distance(region->begin(), region->end()))
return emitError("unreachable blocks were not converted");
return newRegion;
}
Function *impl::FunctionConversion::convertFunction(Function *f) {
assert(f && "expected function");
Function *impl::FunctionConversion::convertFunction() {
Function *f = rewriter.getFunction();
MLIRContext *context = f->getContext();
auto emitError = [context](llvm::Twine f) -> Function * {
context->emitError(UnknownLoc::get(context), f.str());
@ -278,8 +243,8 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
f->getLoc(), f->getName().strref(), newFunctionType.cast<FunctionType>(),
f->getAttrs(), newFunctionArgAttrs);
// Return early if the function has no blocks.
if (f->getBlocks().empty())
// Return early if the function is external.
if (f->isExternal())
return newFunction.release();
auto newBody = convertRegion(context, &f->getBody(), f);
@ -290,54 +255,6 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
return newFunction.release();
}
LogicalResult impl::FunctionConversion::convert(DialectConversion *conversion,
Module *module) {
return impl::FunctionConversion(conversion).run(module);
}
LogicalResult impl::FunctionConversion::run(Module *module) {
if (!module)
return failure();
MLIRContext *context = module->getContext();
conversions = dialectConversion->initConverters(context);
// Convert the functions but don't add them to the module yet to avoid
// converted functions to be converted again.
SmallVector<Function *, 0> originalFuncs, convertedFuncs;
DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
originalFuncs.reserve(module->getFunctions().size());
for (auto &func : *module)
originalFuncs.push_back(&func);
convertedFuncs.reserve(module->getFunctions().size());
for (auto *func : originalFuncs) {
Function *converted = convertFunction(func);
if (!converted)
return failure();
auto origFuncAttr = FunctionAttr::get(func);
auto convertedFuncAttr = FunctionAttr::get(converted);
convertedFuncs.push_back(converted);
functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
}
// Remap function attributes in the converted functions (they are not yet in
// the module). Original functions will disappear anyway so there is no
// need to remap attributes in them.
for (const auto &funcPair : functionAttrRemapping) {
remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
}
// Remove original functions from the module, then insert converted
// functions. The order is important to avoid name collisions.
for (auto &func : originalFuncs)
func->erase();
for (auto *func : convertedFuncs)
module->getFunctions().push_back(func);
return success();
}
// Create a function type with arguments and results converted, and argument
// attributes passed through.
FunctionType DialectConversion::convertFunctionSignatureType(
@ -363,6 +280,55 @@ FunctionType DialectConversion::convertFunctionSignatureType(
return FunctionType::get(arguments, results, type.getContext());
}
LogicalResult DialectConversion::convert(Module *m) {
return impl::FunctionConversion::convert(this, m);
// Converts the module as follows.
// 1. Call `convertFunction` on each function of the module and collect the
// mapping between old and new functions.
// 2. Remap all function attributes in the new functions to point to the new
// functions instead of the old ones.
// 3. Replace old functions with the new in the module.
LogicalResult DialectConversion::convert(Module *module) {
if (!module)
return failure();
// Grab the conversion patterns from the converter and create the pattern
// matcher.
MLIRContext *context = module->getContext();
OwningRewritePatternList patterns;
initConverters(patterns, context);
RewritePatternMatcher matcher(std::move(patterns));
// Convert the functions but don't add them to the module yet to avoid
// converted functions to be converted again.
SmallVector<Function *, 0> originalFuncs, convertedFuncs;
DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
originalFuncs.reserve(module->getFunctions().size());
for (auto &func : *module)
originalFuncs.push_back(&func);
convertedFuncs.reserve(module->getFunctions().size());
for (auto *func : originalFuncs) {
impl::FunctionConversion converter(this, func, matcher);
Function *converted = converter.convertFunction();
if (!converted)
return failure();
auto origFuncAttr = FunctionAttr::get(func);
auto convertedFuncAttr = FunctionAttr::get(converted);
convertedFuncs.push_back(converted);
functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
}
// Remap function attributes in the converted functions (they are not yet in
// the module). Original functions will disappear anyway so there is no
// need to remap attributes in them.
for (const auto &funcPair : functionAttrRemapping)
remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
// Remove original functions from the module, then insert converted
// functions. The order is important to avoid name collisions.
for (auto &func : originalFuncs)
func->erase();
for (auto *func : convertedFuncs)
module->getFunctions().push_back(func);
return success();
}

View File

@ -46,7 +46,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(Function &fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(&fn), matcher(std::move(patterns), *this) {
: PatternRewriter(&fn), matcher(std::move(patterns)) {
worklist.reserve(64);
}
@ -202,7 +202,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
// Try to match one of the canonicalization patterns. The rewriter is
// automatically notified of any necessary changes, so there is nothing
// else to do here.
changed |= matcher.matchAndRewrite(op);
changed |= matcher.matchAndRewrite(op, *this);
}
} while (changed && ++i < maxIterations);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.