forked from OSchip/llvm-project
Rewrite the DialectOpConversion patterns to inherit from RewritePattern instead of Pattern. This simplifies the infrastructure a bit by being able to reuse PatternRewriter and the RewritePatternMatcher, but also starts to lay the groundwork for a more generalized legalization framework that can operate on DialectOpConversions as well as normal RewritePatterns.
-- PiperOrigin-RevId: 248836492
This commit is contained in:
parent
b5ecbb7fd6
commit
3de0c7696b
|
@ -28,7 +28,9 @@ class DialectConversion;
|
|||
class DialectOpConversion;
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
class RewritePattern;
|
||||
class Type;
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
namespace LLVM {
|
||||
class LLVMType;
|
||||
} // end namespace LLVM
|
||||
|
@ -39,22 +41,18 @@ namespace linalg {
|
|||
/// Keep all other types unmodified.
|
||||
mlir::Type convertLinalgType(mlir::Type t);
|
||||
|
||||
/// Allocate the conversion patterns for RangeOp, ViewOp and SliceOp from the
|
||||
/// Linalg dialect to the LLVM IR dialect. The converters are allocated in the
|
||||
/// `allocator` using the provided `context`. The latter must have the LLVM IR
|
||||
/// dialect registered.
|
||||
/// This function can be used to apply multiple conversion patterns in the same
|
||||
/// pass. It does not have to be called explicitly before the conversion.
|
||||
llvm::DenseSet<mlir::DialectOpConversion *>
|
||||
allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
|
||||
/// Get the conversion patterns for RangeOp, ViewOp and SliceOp from the Linalg
|
||||
/// dialect to the LLVM IR dialect. The LLVM IR dialect must be registered. This
|
||||
/// function can be used to apply multiple conversion patterns in the same pass.
|
||||
/// It does not have to be called explicitly before the conversion.
|
||||
void getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
|
||||
mlir::MLIRContext *context);
|
||||
|
||||
/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect.
|
||||
/// The conversion is set up to convert types and function signatures using
|
||||
/// `convertLinalgType` and obtains operation converters by calling `initer`.
|
||||
std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
|
||||
std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
|
||||
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
|
||||
std::function<void(mlir::OwningRewritePatternList &, mlir::MLIRContext *)>
|
||||
initer);
|
||||
|
||||
/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
|
||||
|
|
|
@ -146,14 +146,8 @@ public:
|
|||
explicit RangeOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (isa<linalg::RangeOp>(op))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto rangeOp = cast<linalg::RangeOp>(op);
|
||||
auto rangeDescriptorType =
|
||||
linalg::convertLinalgType(rangeOp.getResult()->getType());
|
||||
|
@ -169,7 +163,7 @@ public:
|
|||
operands[1], makePositionAttr(rewriter, 1));
|
||||
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
||||
operands[2], makePositionAttr(rewriter, 2));
|
||||
return {rangeDescriptor};
|
||||
rewriter.replaceOp(op, rangeDescriptor);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -178,14 +172,8 @@ public:
|
|||
explicit ViewOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (isa<linalg::ViewOp>(op))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto viewOp = cast<linalg::ViewOp>(op);
|
||||
auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
|
||||
auto memrefType =
|
||||
|
@ -301,7 +289,7 @@ public:
|
|||
++i;
|
||||
}
|
||||
|
||||
return {viewDescriptor};
|
||||
rewriter.replaceOp(op, viewDescriptor);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -310,14 +298,8 @@ public:
|
|||
explicit SliceOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (isa<linalg::SliceOp>(op))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto sliceOp = cast<linalg::SliceOp>(op);
|
||||
auto newViewDescriptorType =
|
||||
linalg::convertLinalgType(sliceOp.getViewType());
|
||||
|
@ -398,7 +380,7 @@ public:
|
|||
stride, pos({3, i}));
|
||||
}
|
||||
|
||||
return {newViewDescriptor};
|
||||
rewriter.replaceOp(op, newViewDescriptor);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -409,41 +391,30 @@ public:
|
|||
explicit DropConsumer(MLIRContext *context)
|
||||
: DialectOpConversion("some_consumer", 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (op->getName().getStringRef() == "some_consumer")
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *, ArrayRef<Value *>,
|
||||
FuncBuilder &) const override {
|
||||
return {};
|
||||
}
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {}
|
||||
};
|
||||
|
||||
llvm::DenseSet<mlir::DialectOpConversion *>
|
||||
linalg::allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
|
||||
void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
return ConversionListBuilder<DropConsumer, RangeOpConversion,
|
||||
SliceOpConversion,
|
||||
ViewOpConversion>::build(allocator, context);
|
||||
ConversionListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
|
||||
ViewOpConversion>::build(patterns, context);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// The conversion class from Linalg to LLVMIR.
|
||||
class Lowering : public DialectConversion {
|
||||
public:
|
||||
explicit Lowering(std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
|
||||
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
|
||||
explicit Lowering(std::function<void(mlir::OwningRewritePatternList &patterns,
|
||||
mlir::MLIRContext *context)>
|
||||
conversions)
|
||||
: setup(conversions) {}
|
||||
|
||||
protected:
|
||||
// Initialize the list of converters.
|
||||
llvm::DenseSet<DialectOpConversion *>
|
||||
initConverters(MLIRContext *context) override {
|
||||
converterStorage.Reset();
|
||||
return setup(&converterStorage, context);
|
||||
void initConverters(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context) override {
|
||||
setup(patterns, context);
|
||||
}
|
||||
|
||||
// This gets called for block and region arguments, and attributes.
|
||||
|
@ -475,19 +446,15 @@ protected:
|
|||
}
|
||||
|
||||
private:
|
||||
// Storage for individual converters.
|
||||
llvm::BumpPtrAllocator converterStorage;
|
||||
|
||||
// Conversion setup.
|
||||
std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
|
||||
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
|
||||
std::function<void(mlir::OwningRewritePatternList &patterns,
|
||||
mlir::MLIRContext *context)>
|
||||
setup;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::DialectConversion> linalg::makeLinalgToLLVMLowering(
|
||||
std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
|
||||
llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
|
||||
std::function<void(mlir::OwningRewritePatternList &, mlir::MLIRContext *)>
|
||||
initer) {
|
||||
return llvm::make_unique<Lowering>(initer);
|
||||
}
|
||||
|
@ -502,7 +469,7 @@ void linalg::convertToLLVM(mlir::Module &module) {
|
|||
|
||||
// Convert Linalg ops to the LLVM IR dialect using the converter defined
|
||||
// above.
|
||||
auto r = Lowering(allocateDescriptorConverters).convert(&module);
|
||||
auto r = Lowering(getDescriptorConverters).convert(&module);
|
||||
(void)r;
|
||||
assert(succeeded(r) && "conversion failed");
|
||||
|
||||
|
|
|
@ -104,15 +104,15 @@ public:
|
|||
// an LLVM IR load.
|
||||
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
|
||||
using Base::Base;
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext edscContext(rewriter, op->getLoc());
|
||||
auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
|
||||
Value *viewDescriptor = operands[0];
|
||||
ArrayRef<Value *> indices = operands.drop_front();
|
||||
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
|
||||
Value *element = intrinsics::load(elementType, ptr);
|
||||
return {element};
|
||||
rewriter.replaceOp(op, {element});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -120,15 +120,14 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
|
|||
// an LLVM IR store.
|
||||
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
|
||||
using Base::Base;
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext edscContext(rewriter, op->getLoc());
|
||||
Value *viewDescriptor = operands[1];
|
||||
Value *data = operands[0];
|
||||
ArrayRef<Value *> indices = operands.drop_front(2);
|
||||
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
|
||||
intrinsics::store(data, ptr);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -136,15 +135,11 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
|
|||
|
||||
// Helper function that allocates the descriptor converters and adds load/store
|
||||
// coverters to the list.
|
||||
static llvm::DenseSet<mlir::DialectOpConversion *>
|
||||
allocateConversions(llvm::BumpPtrAllocator *allocator,
|
||||
mlir::MLIRContext *context) {
|
||||
auto conversions = linalg::allocateDescriptorConverters(allocator, context);
|
||||
auto additional =
|
||||
ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build(
|
||||
allocator, context);
|
||||
conversions.insert(additional.begin(), additional.end());
|
||||
return conversions;
|
||||
static void getConversions(mlir::OwningRewritePatternList &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
linalg::getDescriptorConverters(patterns, context);
|
||||
ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
|
||||
context);
|
||||
}
|
||||
|
||||
void linalg::convertLinalg3ToLLVM(Module &module) {
|
||||
|
@ -155,7 +150,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
|
|||
(void)rr;
|
||||
assert(succeeded(rr) && "affine loop lowering failed");
|
||||
|
||||
auto lowering = makeLinalgToLLVMLowering(allocateConversions);
|
||||
auto lowering = makeLinalgToLLVMLowering(getConversions);
|
||||
auto r = lowering->convert(&module);
|
||||
(void)r;
|
||||
assert(succeeded(r) && "conversion failed");
|
||||
|
|
|
@ -86,8 +86,8 @@ public:
|
|||
explicit MulOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
using namespace edsc;
|
||||
using intrinsics::constant_index;
|
||||
using linalg::intrinsics::range;
|
||||
|
@ -115,7 +115,7 @@ public:
|
|||
auto rhsView = view(rhs, {r1, r2});
|
||||
auto resultView = view(result, {r0, r2});
|
||||
rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
|
||||
return {typeCast(rewriter, result, mul.getType())};
|
||||
rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -123,20 +123,16 @@ public:
|
|||
class EarlyLowering : public DialectConversion {
|
||||
protected:
|
||||
// Initialize the list of converters.
|
||||
llvm::DenseSet<DialectOpConversion *>
|
||||
initConverters(MLIRContext *context) override {
|
||||
return ConversionListBuilder<MulOpConversion>::build(&allocator, context);
|
||||
void initConverters(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context) override {
|
||||
ConversionListBuilder<MulOpConversion>::build(patterns, context);
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
};
|
||||
|
||||
/// This is lowering to Linalg the parts that are computationally intensive
|
||||
/// (like matmul for example...) while keeping the rest of the code in the Toy
|
||||
/// dialect.
|
||||
struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
|
||||
|
||||
void runOnModule() override {
|
||||
if (failed(EarlyLowering().convert(&getModule()))) {
|
||||
getModule().getContext()->emitError(
|
||||
|
|
|
@ -91,8 +91,8 @@ public:
|
|||
/// the rewritten operands for `op` in the new function.
|
||||
/// The results created by the new IR with the builder are returned, and their
|
||||
/// number must match the number of result of `op`.
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto add = cast<toy::AddOp>(op);
|
||||
auto loc = add.getLoc();
|
||||
// Create a `toy.alloc` operation to allocate the output buffer for this op.
|
||||
|
@ -119,7 +119,7 @@ public:
|
|||
|
||||
// Return the newly allocated buffer, with a type.cast to preserve the
|
||||
// consumers.
|
||||
return {typeCast(rewriter, result, add.getType())};
|
||||
rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -130,8 +130,8 @@ public:
|
|||
explicit PrintOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Get or create the declaration of the printf function in the module.
|
||||
Function *printfFunc = getPrintf(*op->getFunction()->getModule());
|
||||
|
||||
|
@ -175,7 +175,6 @@ public:
|
|||
});
|
||||
// clang-format on
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -232,8 +231,8 @@ public:
|
|||
explicit ConstantOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
|
||||
auto loc = cstOp.getLoc();
|
||||
auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
|
||||
|
@ -265,7 +264,7 @@ public:
|
|||
constant_float(value, f64Ty);
|
||||
}
|
||||
}
|
||||
return {result};
|
||||
rewriter.replaceOp(op, result);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -275,8 +274,8 @@ public:
|
|||
explicit TransposeOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto transpose = cast<toy::TransposeOp>(op);
|
||||
auto loc = transpose.getLoc();
|
||||
Value *result = memRefTypeCast(
|
||||
|
@ -297,7 +296,7 @@ public:
|
|||
});
|
||||
// clang-format on
|
||||
|
||||
return {typeCast(rewriter, result, transpose.getType())};
|
||||
rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -307,8 +306,8 @@ public:
|
|||
explicit ReturnOpConversion(MLIRContext *context)
|
||||
: DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto retOp = cast<toy::ReturnOp>(op);
|
||||
using namespace edsc;
|
||||
auto loc = retOp.getLoc();
|
||||
|
@ -317,7 +316,6 @@ public:
|
|||
rewriter.create<ReturnOp>(loc, operands[0]);
|
||||
else
|
||||
rewriter.create<ReturnOp>(loc);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -326,31 +324,25 @@ public:
|
|||
class LateLowering : public DialectConversion {
|
||||
protected:
|
||||
/// Initialize the list of converters.
|
||||
llvm::DenseSet<DialectOpConversion *>
|
||||
initConverters(MLIRContext *context) override {
|
||||
return ConversionListBuilder<AddOpConversion, PrintOpConversion,
|
||||
ConstantOpConversion, TransposeOpConversion,
|
||||
ReturnOpConversion>::build(&allocator,
|
||||
context);
|
||||
void initConverters(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context) override {
|
||||
ConversionListBuilder<AddOpConversion, PrintOpConversion,
|
||||
ConstantOpConversion, TransposeOpConversion,
|
||||
ReturnOpConversion>::build(patterns, context);
|
||||
}
|
||||
|
||||
/// Convert a Toy type, this gets called for block and region arguments, and
|
||||
/// attributes.
|
||||
Type convertType(Type t) override {
|
||||
if (auto array = t.cast<toy::ToyArrayType>()) {
|
||||
if (auto array = t.cast<toy::ToyArrayType>())
|
||||
return array.toMemref();
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
};
|
||||
|
||||
/// This is lowering to Linalg the parts that can be (matmul and add on arrays)
|
||||
/// and is targeting LLVM otherwise.
|
||||
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
||||
|
||||
void runOnModule() override {
|
||||
// Perform Toy specific lowering
|
||||
if (failed(LateLowering().convert(&getModule()))) {
|
||||
|
|
|
@ -248,8 +248,8 @@ public:
|
|||
/// clients can specify a list of other nodes that this replacement may make
|
||||
/// (perhaps transitively) dead. If any of those values are dead, this will
|
||||
/// remove them as well.
|
||||
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead = {});
|
||||
virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead = {});
|
||||
|
||||
/// Replaces the result op with a new op that is created without verification.
|
||||
/// The result values of the two ops must be the same types.
|
||||
|
@ -326,14 +326,12 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
|||
///
|
||||
class RewritePatternMatcher {
|
||||
public:
|
||||
/// Create a RewritePatternMatcher with the specified set of patterns and
|
||||
/// rewriter.
|
||||
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns,
|
||||
PatternRewriter &rewriter);
|
||||
/// Create a RewritePatternMatcher with the specified set of patterns.
|
||||
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
|
||||
|
||||
/// Try to match the given operation to a pattern and rewrite it. Return
|
||||
/// true if any pattern matches.
|
||||
bool matchAndRewrite(Operation *op);
|
||||
bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
|
||||
|
||||
private:
|
||||
RewritePatternMatcher(const RewritePatternMatcher &) = delete;
|
||||
|
@ -342,9 +340,6 @@ private:
|
|||
/// The group of patterns that are matched for optimization through this
|
||||
/// matcher.
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
/// The rewriter used when applying matched patterns.
|
||||
PatternRewriter &rewriter;
|
||||
};
|
||||
|
||||
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||
|
|
|
@ -56,18 +56,15 @@ public:
|
|||
llvm::LLVMContext &getLLVMContext();
|
||||
|
||||
protected:
|
||||
/// Create a set of converters that live in the pass object by passing them a
|
||||
/// reference to the LLVM IR dialect. Store the module associated with the
|
||||
/// dialect for further type conversion.
|
||||
llvm::DenseSet<DialectOpConversion *>
|
||||
initConverters(MLIRContext *mlirContext) override final;
|
||||
/// Add a set of converters to the given pattern list. Store the module
|
||||
/// associated with the dialect for further type conversion.
|
||||
void initConverters(OwningRewritePatternList &patterns,
|
||||
MLIRContext *mlirContext) override final;
|
||||
|
||||
/// Derived classes can override this function to initialize custom converters
|
||||
/// in addition to the existing converters from Standard operations. It will
|
||||
/// be called after the `module` and `llvmDialect` have been made available.
|
||||
virtual llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() {
|
||||
return {};
|
||||
};
|
||||
virtual void initAdditionalConverters(OwningRewritePatternList &patterns) {}
|
||||
|
||||
/// Derived classes can override this function to convert custom types. It
|
||||
/// will be called by convertType if the default conversion from standard and
|
||||
|
|
|
@ -42,63 +42,65 @@ class FunctionConversion;
|
|||
}
|
||||
|
||||
/// Base class for the dialect op conversion patterns. Specific conversions
|
||||
/// must derive this class and implement least one of `rewrite` and
|
||||
/// `rewriteTerminator`. Optionally they can also override
|
||||
/// `PatternMatch match(Operation *)` to match more specific operations than the
|
||||
/// `rootName` provided in the constructor.
|
||||
//
|
||||
// TODO(zinenko): this should eventually converge with RewritePattern. So far,
|
||||
// rewritePattern is missing support for operations with successors as well as
|
||||
// an ability to accept new operands instead of reusing those of the existing
|
||||
// operation.
|
||||
class DialectOpConversion : public Pattern {
|
||||
/// must derive this class and implement least one `rewrite` method. Optionally
|
||||
/// they can also override `PatternMatch match(Operation *)` to match more
|
||||
/// specific operations than the `rootName` provided in the constructor.
|
||||
/// NOTE: These conversion patterns can only be used with the DialectConversion
|
||||
/// class.
|
||||
class DialectOpConversion : public RewritePattern {
|
||||
public:
|
||||
/// Construct an DialectOpConversion. `rootName` must correspond to the
|
||||
/// canonical name of the first operation matched by the pattern.
|
||||
DialectOpConversion(StringRef rootName, PatternBenefit benefit,
|
||||
MLIRContext *ctx)
|
||||
: Pattern(rootName, benefit, ctx) {}
|
||||
: RewritePattern(rootName, benefit, ctx) {}
|
||||
|
||||
/// Hook for derived classes to implement rewriting. `op` is the (first)
|
||||
/// Hook for derived classes to implement matching. Dialect conversion
|
||||
/// generally unconditionally match the root operation, so default to success
|
||||
/// here.
|
||||
virtual PatternMatchResult match(Operation *op) const override {
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
/// Hook for derived classes to implement rewriting. `op` is the (first)
|
||||
/// operation matched by the pattern, `operands` is a list of rewritten values
|
||||
/// that are passed to this operation, `rewriter` can be used to emit the new
|
||||
/// operations. This function returns the values produced by the newly
|
||||
/// created operation(s). These values will be used instead of those produced
|
||||
/// by the original operation. This function must be reimplemented if the
|
||||
/// DialectOpConversion ever needs to replace an operation that does not have
|
||||
/// successors. This function should not fail. If some specific cases of the
|
||||
/// operation are not supported, these cases should not be matched.
|
||||
virtual SmallVector<Value *, 4> rewrite(Operation *op,
|
||||
ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const {
|
||||
llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?");
|
||||
};
|
||||
/// operations. This function must be reimplemented if the DialectOpConversion
|
||||
/// ever needs to replace an operation that does not have successors. This
|
||||
/// function should not fail. If some specific cases of the operation are not
|
||||
/// supported, these cases should not be matched.
|
||||
virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const {
|
||||
llvm_unreachable("unimplemented rewrite");
|
||||
}
|
||||
|
||||
/// Hook for derived classes to implement rewriting. `op` is the (first)
|
||||
/// Hook for derived classes to implement rewriting. `op` is the (first)
|
||||
/// operation matched by the pattern, `properOperands` is a list of rewritten
|
||||
/// values that are passed to the operation itself, `destinations` is a list
|
||||
/// of (potentially rewritten) successor blocks, `operands` is a list of lists
|
||||
/// of rewritten values passed to each of the successors, co-indexed with
|
||||
/// `destinations`, `rewriter` can be used to emit the new operations. Since
|
||||
/// terminators never produce results (which could not be used anyway), this
|
||||
/// function does not return anything. It must be reimplemented if the
|
||||
/// DialectOpConversion ever needs to replace a terminator operation that has
|
||||
/// successors. This function should not fail the pass. If some specific
|
||||
/// cases of the operation are not supported, these cases should not be
|
||||
/// matched.
|
||||
virtual void rewriteTerminator(Operation *op,
|
||||
ArrayRef<Value *> properOperands,
|
||||
ArrayRef<Block *> destinations,
|
||||
ArrayRef<ArrayRef<Value *>> operands,
|
||||
FuncBuilder &rewriter) const {
|
||||
llvm_unreachable("unimplemented rewriteTerminator, did you mean rewrite?");
|
||||
/// `destinations`, `rewriter` can be used to emit the new operations. It must
|
||||
/// be reimplemented if the DialectOpConversion ever needs to replace a
|
||||
/// terminator operation that has successors. This function should not fail
|
||||
/// the pass. If some specific cases of the operation are not supported,
|
||||
/// these cases should not be matched.
|
||||
virtual void rewrite(Operation *op, ArrayRef<Value *> properOperands,
|
||||
ArrayRef<Block *> destinations,
|
||||
ArrayRef<ArrayRef<Value *>> operands,
|
||||
PatternRewriter &rewriter) const {
|
||||
llvm_unreachable("unimplemented rewrite for terminators");
|
||||
}
|
||||
|
||||
/// Provide a default implementation for matching: most DialectOpConversion
|
||||
/// implementations are unconditionally matching.
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
return matchSuccess();
|
||||
}
|
||||
/// Rewrite the IR rooted at the specified operation with the result of
|
||||
/// this pattern, generating any new operations with the specified
|
||||
/// builder. If an unexpected error is encountered (an internal
|
||||
/// compiler error), it is emitted through the normal MLIR diagnostic
|
||||
/// hooks and the IR is left in a valid state.
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const final;
|
||||
|
||||
private:
|
||||
using RewritePattern::matchAndRewrite;
|
||||
using RewritePattern::rewrite;
|
||||
};
|
||||
|
||||
// Helper class to create a list of dialect conversion patterns given a list of
|
||||
|
@ -106,27 +108,22 @@ public:
|
|||
// conversion constructors.
|
||||
template <typename Arg, typename... Args> struct ConversionListBuilder {
|
||||
template <typename... ConstructorArgs>
|
||||
static llvm::DenseSet<DialectOpConversion *>
|
||||
build(llvm::BumpPtrAllocator *allocator,
|
||||
ConstructorArgs &&... constructorArgs) {
|
||||
auto sub = ConversionListBuilder<Args...>::build(
|
||||
allocator, std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
auto *ptr = allocator->Allocate<Arg>();
|
||||
new (ptr) Arg(std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
sub.insert(ptr);
|
||||
return sub;
|
||||
static void build(OwningRewritePatternList &patterns,
|
||||
ConstructorArgs &&... constructorArgs) {
|
||||
ConversionListBuilder<Args...>::build(
|
||||
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
ConversionListBuilder<Arg>::build(
|
||||
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
}
|
||||
};
|
||||
|
||||
// Template specialization to stop recursion.
|
||||
template <typename Arg> struct ConversionListBuilder<Arg> {
|
||||
template <typename... ConstructorArgs>
|
||||
static llvm::DenseSet<DialectOpConversion *>
|
||||
build(llvm::BumpPtrAllocator *allocator,
|
||||
ConstructorArgs &&... constructorArgs) {
|
||||
auto *ptr = allocator->Allocate<Arg>();
|
||||
new (ptr) Arg(std::forward<ConstructorArgs>(constructorArgs)...);
|
||||
return {ptr};
|
||||
static void build(OwningRewritePatternList &patterns,
|
||||
ConstructorArgs &&... constructorArgs) {
|
||||
patterns.emplace_back(llvm::make_unique<Arg>(
|
||||
std::forward<ConstructorArgs>(constructorArgs)...));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -138,19 +135,17 @@ template <typename Arg> struct ConversionListBuilder<Arg> {
|
|||
/// current MLIR context.
|
||||
/// 2. For each function in the module do the following.
|
||||
// a. Create a new function with the same name and convert its signature
|
||||
// using `convertType`.
|
||||
// using `convertType`.
|
||||
// b. For each block in the function, create a block in the function with
|
||||
// its arguments converted using `convertType`.
|
||||
// its arguments converted using `convertType`.
|
||||
// c. Traverse blocks in DFS-preorder of successors starting from the entry
|
||||
// block (if any), and convert individual operations as follows. Pattern
|
||||
// match against the list of conversions. On the first match, call
|
||||
// `rewriteTerminator` for terminator operations with successors and
|
||||
// `rewrite` for other operations, and advance to the next iteration. If no
|
||||
// match is found, replicate the operation as is. Note that if two patterns
|
||||
// match the same operation, it is undefined which of them will be applied.
|
||||
// block (if any), and convert individual operations as follows. Pattern
|
||||
// match against the list of conversions. On the first match, call
|
||||
// `rewrite` for the operations, and advance to the next iteration. If no
|
||||
// match is found, replicate the operation as is.
|
||||
/// 3. Update all attributes of function type to point to the new functions.
|
||||
/// 4. Replace old functions with new functions in the module.
|
||||
/// If any error happend during the conversion, the pass fails as soon as
|
||||
/// If any error happened during the conversion, the pass fails as soon as
|
||||
/// possible.
|
||||
///
|
||||
/// If the conversion fails, the module is not modified.
|
||||
|
@ -167,9 +162,10 @@ public:
|
|||
protected:
|
||||
/// Derived classes must implement this hook to produce a set of conversion
|
||||
/// patterns to apply. They may use `mlirContext` to obtain registered
|
||||
/// dialects or operations. This will be called in the beginning of the pass.
|
||||
virtual llvm::DenseSet<DialectOpConversion *>
|
||||
initConverters(MLIRContext *mlirContext) = 0;
|
||||
/// dialects or operations. This will be called in the beginning of the
|
||||
/// conversion.
|
||||
virtual void initConverters(OwningRewritePatternList &patterns,
|
||||
MLIRContext *mlirContext) = 0;
|
||||
|
||||
/// Derived classes must reimplement this hook if they need to convert
|
||||
/// block or function argument types or function result types. If the target
|
||||
|
|
|
@ -123,8 +123,8 @@ void PatternRewriter::updatedRootInPlace(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
RewritePatternMatcher::RewritePatternMatcher(
|
||||
OwningRewritePatternList &&patterns, PatternRewriter &rewriter)
|
||||
: patterns(std::move(patterns)), rewriter(rewriter) {
|
||||
OwningRewritePatternList &&patterns)
|
||||
: patterns(std::move(patterns)) {
|
||||
// Sort the patterns by benefit to simplify the matching logic.
|
||||
std::stable_sort(this->patterns.begin(), this->patterns.end(),
|
||||
[](const std::unique_ptr<RewritePattern> &l,
|
||||
|
@ -134,7 +134,8 @@ RewritePatternMatcher::RewritePatternMatcher(
|
|||
}
|
||||
|
||||
/// Try to match the given operation to a pattern and rewrite it.
|
||||
bool RewritePatternMatcher::matchAndRewrite(Operation *op) {
|
||||
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
for (auto &pattern : patterns) {
|
||||
// Ignore patterns that are for the wrong root or are impossible to match.
|
||||
if (pattern->getRootKind() != op->getName() ||
|
||||
|
|
|
@ -212,11 +212,8 @@ public:
|
|||
lowering_),
|
||||
dialect(dialect_) {}
|
||||
|
||||
// Match by type.
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (isa<SourceOp>(op))
|
||||
return this->matchSuccess();
|
||||
return this->matchFailure();
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
// Get the LLVM IR dialect.
|
||||
|
@ -241,7 +238,7 @@ public:
|
|||
}
|
||||
|
||||
// Create an LLVM IR pseudo-operation defining the given index constant.
|
||||
Value *createIndexConstant(FuncBuilder &builder, Location loc,
|
||||
Value *createIndexConstant(PatternRewriter &builder, Location loc,
|
||||
uint64_t value) const {
|
||||
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
|
||||
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
|
||||
|
@ -249,7 +246,7 @@ public:
|
|||
|
||||
// Get the array attribute named "position" containing the given list of
|
||||
// integers as integer attribute elements.
|
||||
static ArrayAttr getIntegerArrayAttr(FuncBuilder &builder,
|
||||
static ArrayAttr getIntegerArrayAttr(PatternRewriter &builder,
|
||||
ArrayRef<int64_t> values) {
|
||||
SmallVector<Attribute, 4> attrs;
|
||||
attrs.reserve(values.size());
|
||||
|
@ -259,7 +256,7 @@ public:
|
|||
}
|
||||
|
||||
// Extract raw data pointer value from a value representing a memref.
|
||||
static Value *extractMemRefElementPtr(FuncBuilder &builder, Location loc,
|
||||
static Value *extractMemRefElementPtr(PatternRewriter &builder, Location loc,
|
||||
Value *convertedMemRefValue,
|
||||
Type elementTypePtr,
|
||||
bool hasStaticShape) {
|
||||
|
@ -297,8 +294,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// Convert the type of the result to an LLVM type, pass operands as is,
|
||||
// preserve attributes.
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
unsigned numResults = op->getNumResults();
|
||||
|
||||
Type packedType;
|
||||
|
@ -314,9 +311,9 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// If the operation produced 0 or 1 result, return them immediately.
|
||||
if (numResults == 0)
|
||||
return {};
|
||||
return;
|
||||
if (numResults == 1)
|
||||
return {newOp.getOperation()->getResult(0)};
|
||||
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0));
|
||||
|
||||
// Otherwise, it had been converted to an operation producing a structure.
|
||||
// Extract individual results from the structure and return them as list.
|
||||
|
@ -328,7 +325,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
op->getLoc(), type, newOp.getOperation()->getResult(0),
|
||||
this->getIntegerArrayAttr(rewriter, i)));
|
||||
}
|
||||
return results;
|
||||
rewriter.replaceOp(op, results);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -417,15 +414,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (!LLVMLegalizationPattern<AllocOp>::match(op))
|
||||
return matchFailure();
|
||||
auto allocOp = cast<AllocOp>(op);
|
||||
MemRefType type = allocOp.getType();
|
||||
MemRefType type = cast<AllocOp>(op).getType();
|
||||
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto allocOp = cast<AllocOp>(op);
|
||||
MemRefType type = allocOp.getType();
|
||||
|
||||
|
@ -449,7 +443,6 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
op->getLoc(), getIndexType(),
|
||||
ArrayRef<Value *>{cumulativeSize, sizes[i]});
|
||||
|
||||
|
||||
// Compute the total amount of bytes to allocate.
|
||||
auto elementType = type.getElementType();
|
||||
assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
|
||||
|
@ -493,9 +486,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
ArrayRef<Value *>(allocated));
|
||||
|
||||
// Deal with static memrefs
|
||||
if (numOperands == 0) {
|
||||
return {allocated};
|
||||
}
|
||||
if (numOperands == 0)
|
||||
return rewriter.replaceOp(op, allocated);
|
||||
|
||||
// Create the MemRef descriptor.
|
||||
auto structType = lowering.convertType(type);
|
||||
|
@ -515,7 +507,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
}
|
||||
|
||||
// Return the final value of the descriptor.
|
||||
return {memRefDescriptor};
|
||||
rewriter.replaceOp(op, memRefDescriptor);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -525,8 +517,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
||||
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
assert(operands.size() == 1 && "dealloc takes one operand");
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
|
@ -552,7 +544,6 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
op->getLoc(), getVoidPtrType(), bufferPtr);
|
||||
rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(),
|
||||
rewriter.getFunctionAttr(freeFunc), casted);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -560,8 +551,6 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
|
||||
return matchFailure();
|
||||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
MemRefType sourceType =
|
||||
memRefCastOp.getOperand()->getType().cast<MemRefType>();
|
||||
|
@ -572,8 +561,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
: matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
auto targetType = memRefCastOp.getType();
|
||||
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
|
||||
|
@ -584,9 +573,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
|
||||
elementTypePtr, sourceType.hasStaticShape());
|
||||
// Account for static memrefs as target types
|
||||
if (targetType.hasStaticShape()) {
|
||||
return {buffer};
|
||||
}
|
||||
if (targetType.hasStaticShape())
|
||||
return rewriter.replaceOp(op, buffer);
|
||||
|
||||
// Create the new MemRef descriptor.
|
||||
auto structType = lowering.convertType(targetType);
|
||||
|
@ -629,7 +617,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
|
||||
"target dynamic dimensions were not set up");
|
||||
|
||||
return {newDescriptor};
|
||||
rewriter.replaceOp(op, newDescriptor);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -639,20 +627,17 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
|||
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (!LLVMLegalizationPattern<DimOp>::match(op))
|
||||
return this->matchFailure();
|
||||
auto dimOp = cast<DimOp>(op);
|
||||
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
|
||||
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
assert(operands.size() == 1 && "expected exactly one operand");
|
||||
auto dimOp = cast<DimOp>(op);
|
||||
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
|
||||
|
||||
SmallVector<Value *, 4> results;
|
||||
auto shape = type.getShape();
|
||||
uint64_t index = dimOp.getIndex();
|
||||
// Extract dynamic size from the memref descriptor and define static size
|
||||
|
@ -666,14 +651,13 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
|||
if (shape[i] == -1)
|
||||
++position;
|
||||
}
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), getIndexType(), operands[0],
|
||||
getIntegerArrayAttr(rewriter, position)));
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
|
||||
op, getIndexType(), operands[0],
|
||||
getIntegerArrayAttr(rewriter, position));
|
||||
} else {
|
||||
results.push_back(
|
||||
createIndexConstant(rewriter, op->getLoc(), shape[index]));
|
||||
rewriter.replaceOp(
|
||||
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -686,10 +670,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
using Base = LoadStoreOpLowering<Derived>;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (!LLVMLegalizationPattern<Derived>::match(op))
|
||||
return this->matchFailure();
|
||||
auto loadOp = cast<Derived>(op);
|
||||
MemRefType type = loadOp.getMemRefType();
|
||||
MemRefType type = cast<Derived>(op).getMemRefType();
|
||||
return isSupportedMemRefType(type) ? this->matchSuccess()
|
||||
: this->matchFailure();
|
||||
}
|
||||
|
@ -702,7 +683,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
// by accumulating the running linearized value.
|
||||
// Note that `indices` and `allocSizes` are passed in the same order as they
|
||||
// appear in load/store operations and memref type declarations.
|
||||
Value *linearizeSubscripts(FuncBuilder &builder, Location loc,
|
||||
Value *linearizeSubscripts(PatternRewriter &builder, Location loc,
|
||||
ArrayRef<Value *> indices,
|
||||
ArrayRef<Value *> allocSizes) const {
|
||||
assert(indices.size() == allocSizes.size() &&
|
||||
|
@ -727,7 +708,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
// indices.
|
||||
Value *getElementPtr(Location loc, Type elementTypePtr,
|
||||
ArrayRef<int64_t> shape, Value *memRefDescriptor,
|
||||
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
|
||||
ArrayRef<Value *> indices,
|
||||
PatternRewriter &rewriter) const {
|
||||
// Get the list of MemRef sizes. Static sizes are defined as constants.
|
||||
// Dynamic sizes are extracted from the MemRef descriptor, where they start
|
||||
// from the position 1 (the buffer is at position 0).
|
||||
|
@ -763,7 +745,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
Value *getRawElementPtr(Location loc, Type elementTypePtr,
|
||||
ArrayRef<int64_t> shape, Value *rawDataPtr,
|
||||
ArrayRef<Value *> indices,
|
||||
FuncBuilder &rewriter) const {
|
||||
PatternRewriter &rewriter) const {
|
||||
if (shape.empty())
|
||||
return rawDataPtr;
|
||||
|
||||
|
@ -779,16 +761,15 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
}
|
||||
|
||||
Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
|
||||
ArrayRef<Value *> indices, FuncBuilder &rewriter,
|
||||
ArrayRef<Value *> indices, PatternRewriter &rewriter,
|
||||
llvm::Module &module) const {
|
||||
auto ptrType = getMemRefElementPtrType(type, this->lowering);
|
||||
auto shape = type.getShape();
|
||||
if (type.hasStaticShape()) {
|
||||
// NB: If memref was statically-shaped, dataPtr is pointer to raw data.
|
||||
return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
|
||||
} else {
|
||||
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
|
||||
}
|
||||
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -797,8 +778,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
||||
using Base::Base;
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loadOp = cast<LoadOp>(op);
|
||||
auto type = loadOp.getMemRefType();
|
||||
|
||||
|
@ -806,10 +787,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
|||
operands.drop_front(), rewriter, getModule());
|
||||
auto elementType = lowering.convertType(type.getElementType());
|
||||
|
||||
SmallVector<Value *, 4> results;
|
||||
results.push_back(rewriter.create<LLVM::LoadOp>(
|
||||
op->getLoc(), elementType, ArrayRef<Value *>{dataPtr}));
|
||||
return results;
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
|
||||
ArrayRef<Value *>{dataPtr});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -818,16 +797,13 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
|||
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
||||
using Base::Base;
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
auto storeOp = cast<StoreOp>(op);
|
||||
auto type = storeOp.getMemRefType();
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto type = cast<StoreOp>(op).getMemRefType();
|
||||
|
||||
Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
|
||||
operands.drop_front(2), rewriter, getModule());
|
||||
|
||||
rewriter.create<LLVM::StoreOp>(op->getLoc(), operands[0], dataPtr);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -838,10 +814,10 @@ struct OneToOneLLVMTerminatorLowering
|
|||
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
|
||||
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
||||
|
||||
void rewriteTerminator(Operation *op, ArrayRef<Value *> properOperands,
|
||||
ArrayRef<Block *> destinations,
|
||||
ArrayRef<ArrayRef<Value *>> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> properOperands,
|
||||
ArrayRef<Block *> destinations,
|
||||
ArrayRef<ArrayRef<Value *>> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.create<TargetOp>(op->getLoc(), properOperands, destinations,
|
||||
operands, op->getAttrs());
|
||||
}
|
||||
|
@ -856,8 +832,8 @@ struct OneToOneLLVMTerminatorLowering
|
|||
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
||||
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
unsigned numArguments = op->getNumOperands();
|
||||
|
||||
// If ReturnOp has 0 or 1 operand, create it and return immediately.
|
||||
|
@ -865,14 +841,14 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||
rewriter.create<LLVM::ReturnOp>(
|
||||
op->getLoc(), llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
|
||||
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
|
||||
return {};
|
||||
return;
|
||||
}
|
||||
if (numArguments == 1) {
|
||||
rewriter.create<LLVM::ReturnOp>(
|
||||
op->getLoc(), llvm::ArrayRef<Value *>(operands.front()),
|
||||
llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
|
||||
op->getAttrs());
|
||||
return {};
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, we need to pack the arguments into an LLVM struct type before
|
||||
|
@ -888,7 +864,6 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||
rewriter.create<LLVM::ReturnOp>(
|
||||
op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
|
||||
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -954,20 +929,19 @@ void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
|
|||
// Create a set of converters that live in the pass object by passing them a
|
||||
// reference to the LLVM IR dialect. Store the module associated with the
|
||||
// dialect for further type conversion.
|
||||
llvm::DenseSet<DialectOpConversion *>
|
||||
LLVMLowering::initConverters(MLIRContext *mlirContext) {
|
||||
converterStorage.Reset();
|
||||
void LLVMLowering::initConverters(OwningRewritePatternList &patterns,
|
||||
MLIRContext *mlirContext) {
|
||||
llvmDialect = mlirContext->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
if (!llvmDialect) {
|
||||
mlirContext->emitError(UnknownLoc::get(mlirContext),
|
||||
"LLVM IR dialect is not registered");
|
||||
return {};
|
||||
return;
|
||||
}
|
||||
|
||||
module = &llvmDialect->getLLVMModule();
|
||||
|
||||
// FIXME: this should be tablegen'ed
|
||||
auto converters = ConversionListBuilder<
|
||||
ConversionListBuilder<
|
||||
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
|
||||
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
|
||||
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
|
||||
|
@ -975,11 +949,8 @@ LLVMLowering::initConverters(MLIRContext *mlirContext) {
|
|||
LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
|
||||
OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
|
||||
ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
|
||||
SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect,
|
||||
*this);
|
||||
auto extraConverters = initAdditionalConverters();
|
||||
converters.insert(extraConverters.begin(), extraConverters.end());
|
||||
return converters;
|
||||
SubIOpLowering, XOrOpLowering>::build(patterns, *llvmDialect, *this);
|
||||
initAdditionalConverters(patterns);
|
||||
}
|
||||
|
||||
// Convert types using the stored LLVM IR module.
|
||||
|
|
|
@ -147,7 +147,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
|
|||
|
||||
// Create an array attribute containing integer attributes with values provided
|
||||
// in `position`.
|
||||
static ArrayAttr positionAttr(FuncBuilder &builder, ArrayRef<int> position) {
|
||||
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
|
||||
SmallVector<Attribute, 4> attrs;
|
||||
attrs.reserve(position.size());
|
||||
for (auto p : position)
|
||||
|
@ -162,8 +162,8 @@ public:
|
|||
LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto indexType = IndexType::get(op->getContext());
|
||||
auto voidPtrTy = LLVM::LLVMType::get(
|
||||
op->getContext(),
|
||||
|
@ -207,7 +207,7 @@ public:
|
|||
positionAttr(rewriter, 0));
|
||||
desc = insertvalue(bufferDescriptorType, desc, size,
|
||||
positionAttr(rewriter, 1));
|
||||
return {desc};
|
||||
rewriter.replaceOp(op, desc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -219,8 +219,8 @@ public:
|
|||
: LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto voidPtrTy = LLVM::LLVMType::get(
|
||||
op->getContext(),
|
||||
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
|
||||
|
@ -243,8 +243,6 @@ public:
|
|||
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
|
||||
positionAttr(rewriter, 0)));
|
||||
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
|
||||
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -254,11 +252,12 @@ public:
|
|||
BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto int64Ty = lowering.convertType(operands[0]->getType());
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
return {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))};
|
||||
rewriter.replaceOp(
|
||||
op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -268,14 +267,16 @@ public:
|
|||
explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dimOp = cast<linalg::DimOp>(op);
|
||||
auto indexTy = lowering.convertType(rewriter.getIndexType());
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
return {extractvalue(
|
||||
indexTy, operands[0],
|
||||
positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
|
||||
rewriter.replaceOp(
|
||||
op,
|
||||
{extractvalue(
|
||||
indexTy, operands[0],
|
||||
positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -293,7 +294,8 @@ public:
|
|||
// descriptor to emit IR iteratively computing the actual offset, followed by
|
||||
// a getelementptr. This must be called under an edsc::ScopedContext.
|
||||
Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
|
||||
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
|
||||
ArrayRef<Value *> indices,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loadOp = cast<Op>(op);
|
||||
auto elementTy = rewriter.getType<LLVMType>(
|
||||
getPtrToElementType(loadOp.getViewType(), lowering));
|
||||
|
@ -320,15 +322,14 @@ public:
|
|||
// an LLVM IR load.
|
||||
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
|
||||
using Base::Base;
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext edscContext(rewriter, op->getLoc());
|
||||
auto elementTy = lowering.convertType(*op->getResultTypes().begin());
|
||||
Value *viewDescriptor = operands[0];
|
||||
ArrayRef<Value *> indices = operands.drop_front();
|
||||
auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
|
||||
Value *element = llvm_load(elementTy, ptr);
|
||||
return {element};
|
||||
rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -338,8 +339,8 @@ public:
|
|||
explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto rangeOp = cast<RangeOp>(op);
|
||||
auto rangeDescriptorTy =
|
||||
convertLinalgType(rangeOp.getResult()->getType(), lowering);
|
||||
|
@ -354,8 +355,7 @@ public:
|
|||
positionAttr(rewriter, 1));
|
||||
desc = insertvalue(rangeDescriptorTy, desc, operands[2],
|
||||
positionAttr(rewriter, 2));
|
||||
|
||||
return {desc};
|
||||
rewriter.replaceOp(op, desc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -367,8 +367,8 @@ public:
|
|||
: LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto rangeIntersectOp = cast<RangeIntersectOp>(op);
|
||||
auto rangeDescriptorTy =
|
||||
convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
|
||||
|
@ -400,8 +400,7 @@ public:
|
|||
// TODO(ntv): this assumes both steps are one for now. Enforce and extend.
|
||||
desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2),
|
||||
positionAttr(rewriter, 2));
|
||||
|
||||
return {desc};
|
||||
rewriter.replaceOp(op, desc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -410,8 +409,8 @@ public:
|
|||
explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto sliceOp = cast<SliceOp>(op);
|
||||
auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
|
||||
auto viewType = sliceOp.getBaseViewType();
|
||||
|
@ -483,7 +482,7 @@ public:
|
|||
++i;
|
||||
}
|
||||
|
||||
return {desc};
|
||||
rewriter.replaceOp(op, desc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -491,15 +490,14 @@ public:
|
|||
// an LLVM IR store.
|
||||
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
|
||||
using Base::Base;
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext edscContext(rewriter, op->getLoc());
|
||||
Value *data = operands[0];
|
||||
Value *viewDescriptor = operands[1];
|
||||
ArrayRef<Value *> indices = operands.drop_front(2);
|
||||
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
|
||||
llvm_store(data, ptr);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -508,8 +506,8 @@ public:
|
|||
explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto viewOp = cast<ViewOp>(op);
|
||||
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
|
||||
auto elementTy = rewriter.getType<LLVMType>(
|
||||
|
@ -556,7 +554,7 @@ public:
|
|||
runningStride = mul(runningStride, max);
|
||||
}
|
||||
|
||||
return {desc};
|
||||
rewriter.replaceOp(op, desc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -568,18 +566,20 @@ public:
|
|||
|
||||
static StringRef libraryFunctionName() { return "linalg_dot"; }
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto *f =
|
||||
op->getFunction()->getModule()->getNamedFunction(libraryFunctionName());
|
||||
if (!f)
|
||||
if (!f) {
|
||||
op->emitError("Could not find function: " + libraryFunctionName() +
|
||||
"in lowering to LLVM ");
|
||||
return;
|
||||
}
|
||||
|
||||
auto fAttr = rewriter.getFunctionAttr(f);
|
||||
auto named = rewriter.getNamedAttr("callee", fAttr);
|
||||
rewriter.create<LLVM::CallOp>(op->getLoc(), operands,
|
||||
ArrayRef<NamedAttribute>{named});
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -587,14 +587,13 @@ namespace {
|
|||
// The conversion class from Linalg to LLVMIR.
|
||||
class Lowering : public LLVMLowering {
|
||||
protected:
|
||||
llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
|
||||
void initAdditionalConverters(OwningRewritePatternList &patterns) override {
|
||||
return ConversionListBuilder<
|
||||
BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||
BufferSizeOpConversion, DimOpConversion, DotOpConversion,
|
||||
LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion,
|
||||
SliceOpConversion, StoreOpConversion,
|
||||
ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
|
||||
*this);
|
||||
ViewOpConversion>::build(patterns, llvmDialect->getContext(), *this);
|
||||
}
|
||||
|
||||
Type convertAdditionalType(Type t) override {
|
||||
|
|
|
@ -28,6 +28,84 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// This class implements a pattern rewriter for DialectOpConversion patterns.
|
||||
/// It automatically performs remapping of replaced operation values.
|
||||
struct DialectConversionRewriter final : public PatternRewriter {
|
||||
DialectConversionRewriter(Function *fn) : PatternRewriter(fn) {}
|
||||
~DialectConversionRewriter() = default;
|
||||
|
||||
// Implement the hook for replacing an operation with new values.
|
||||
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead) override {
|
||||
assert(newValues.size() == op->getNumResults());
|
||||
for (unsigned i = 0, e = newValues.size(); i < e; ++i)
|
||||
mapping.map(op->getResult(i), newValues[i]);
|
||||
}
|
||||
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
Operation *createOperation(const OperationState &state) override {
|
||||
return FuncBuilder::createOperation(state);
|
||||
}
|
||||
|
||||
void lookupValues(Operation::operand_range operands,
|
||||
SmallVectorImpl<Value *> &remapped) {
|
||||
remapped.reserve(llvm::size(operands));
|
||||
for (Value *operand : operands) {
|
||||
Value *value = mapping.lookupOrNull(operand);
|
||||
assert(value && "converting op before ops defining its operands");
|
||||
remapped.push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Mapping between values(blocks) in the original function and in the new
|
||||
// function.
|
||||
BlockAndValueMapping mapping;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Rewrite the IR rooted at the specified operation with the result of
|
||||
/// this pattern, generating any new operations with the specified
|
||||
/// builder. If an unexpected error is encountered (an internal
|
||||
/// compiler error), it is emitted through the normal MLIR diagnostic
|
||||
/// hooks and the IR is left in a valid state.
|
||||
void DialectOpConversion::rewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
SmallVector<Value *, 4> operands;
|
||||
auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
|
||||
dialectRewriter.lookupValues(op->getOperands(), operands);
|
||||
|
||||
// If this operation has no successors, invoke the rewrite directly.
|
||||
if (op->getNumSuccessors() == 0)
|
||||
return rewrite(op, operands, rewriter);
|
||||
|
||||
// Otherwise, we need to remap the successors.
|
||||
SmallVector<Block *, 2> destinations;
|
||||
destinations.reserve(op->getNumSuccessors());
|
||||
|
||||
SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
|
||||
unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
|
||||
for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
|
||||
// Lookup the successor.
|
||||
auto *successor = dialectRewriter.mapping.lookupOrNull(op->getSuccessor(i));
|
||||
assert(successor && "block was not remapped");
|
||||
destinations.push_back(successor);
|
||||
|
||||
// Lookup the successors operands.
|
||||
unsigned n = op->getNumSuccessorOperands(i);
|
||||
operandsPerDestination.push_back(
|
||||
llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
|
||||
seen += n;
|
||||
}
|
||||
|
||||
// Rewrite the operation.
|
||||
rewrite(op,
|
||||
llvm::makeArrayRef(operands.data(),
|
||||
operands.data() + firstSuccessorOperand),
|
||||
destinations, operandsPerDestination, rewriter);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace impl {
|
||||
// Implementation detail class of the DialectConversion pass. Performs
|
||||
|
@ -36,24 +114,14 @@ namespace impl {
|
|||
// old functions with the new ones in the module.
|
||||
class FunctionConversion {
|
||||
public:
|
||||
// Entry point. Uses hooks defined in `conversion` to obtain the list of
|
||||
// conversion patterns and to convert function and block argument types.
|
||||
// Converts the `module` in-place by replacing all existing functions with the
|
||||
// converted ones.
|
||||
static LogicalResult convert(DialectConversion *conversion, Module *module);
|
||||
|
||||
private:
|
||||
// Constructs a FunctionConversion by storing the hooks.
|
||||
explicit FunctionConversion(DialectConversion *conversion)
|
||||
: dialectConversion(conversion) {}
|
||||
explicit FunctionConversion(DialectConversion *conversion, Function *func,
|
||||
RewritePatternMatcher &matcher)
|
||||
: dialectConversion(conversion), rewriter(func), matcher(matcher) {}
|
||||
|
||||
// Utility that looks up a list of value in the value remapping table. Returns
|
||||
// an empty vector if one of the values is not mapped yet.
|
||||
SmallVector<Value *, 4> lookupValues(Operation::operand_range operands);
|
||||
|
||||
// Converts the given function to the dialect using hooks defined in
|
||||
// Converts the current function to the dialect using hooks defined in
|
||||
// `dialectConversion`. Returns the converted function or `nullptr` on error.
|
||||
Function *convertFunction(Function *f);
|
||||
Function *convertFunction();
|
||||
|
||||
// Converts the given region starting from the entry block and following the
|
||||
// block successors. Returns the converted region or `nullptr` on error.
|
||||
|
@ -61,19 +129,6 @@ private:
|
|||
std::unique_ptr<Region> convertRegion(MLIRContext *context, Region *region,
|
||||
RegionParent *parent);
|
||||
|
||||
// Converts an operation with successors. Extracts the converted operands
|
||||
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
|
||||
// passes them to `converter->rewriteTerminator` function defined in the
|
||||
// pattern, together with `builder`.
|
||||
LogicalResult convertOpWithSuccessors(DialectOpConversion *converter,
|
||||
Operation *op, FuncBuilder &builder);
|
||||
|
||||
// Converts an operation without successors. Extracts the converted operands
|
||||
// from `valueRemapping` and passes them to the `converter->rewrite` function
|
||||
// defined in the pattern, together with `builder`.
|
||||
LogicalResult convertOp(DialectOpConversion *converter, Operation *op,
|
||||
FuncBuilder &builder);
|
||||
|
||||
// Converts a block by traversing its operations sequentially, looking for
|
||||
// the first pattern match and dispatching the operation conversion to
|
||||
// either `convertOp` or `convertOpWithSuccessors` depending on the presence
|
||||
|
@ -81,128 +136,40 @@ private:
|
|||
//
|
||||
// After converting operations, traverses the successor blocks unless they
|
||||
// have been visited already as indicated in `visitedBlocks`.
|
||||
LogicalResult convertBlock(Block *block, FuncBuilder &builder,
|
||||
LogicalResult convertBlock(Block *block,
|
||||
llvm::DenseSet<Block *> &visitedBlocks);
|
||||
|
||||
// Converts the module as follows.
|
||||
// 1. Call `convertFunction` on each function of the module and collect the
|
||||
// mapping between old and new functions.
|
||||
// 2. Remap all function attributes in the new functions to point to the new
|
||||
// functions instead of the old ones.
|
||||
// 3. Replace old functions with the new in the module.
|
||||
LogicalResult run(Module *m);
|
||||
|
||||
// Pointer to a specific dialect pass.
|
||||
DialectConversion *dialectConversion;
|
||||
|
||||
// Set of known conversion patterns.
|
||||
llvm::DenseSet<DialectOpConversion *> conversions;
|
||||
/// The writer used when rewriting operations.
|
||||
DialectConversionRewriter rewriter;
|
||||
|
||||
// Mapping between values(blocks) in the original function and in the new
|
||||
// function.
|
||||
BlockAndValueMapping mapping;
|
||||
/// The matcher use when converting operations.
|
||||
RewritePatternMatcher &matcher;
|
||||
};
|
||||
} // end namespace impl
|
||||
} // end namespace mlir
|
||||
|
||||
SmallVector<Value *, 4>
|
||||
impl::FunctionConversion::lookupValues(Operation::operand_range operands) {
|
||||
SmallVector<Value *, 4> remapped;
|
||||
remapped.reserve(llvm::size(operands));
|
||||
for (Value *operand : operands) {
|
||||
Value *value = mapping.lookupOrNull(operand);
|
||||
if (!value)
|
||||
return {};
|
||||
remapped.push_back(value);
|
||||
}
|
||||
return remapped;
|
||||
}
|
||||
|
||||
LogicalResult impl::FunctionConversion::convertOpWithSuccessors(
|
||||
DialectOpConversion *converter, Operation *op, FuncBuilder &builder) {
|
||||
SmallVector<Block *, 2> destinations;
|
||||
destinations.reserve(op->getNumSuccessors());
|
||||
SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
|
||||
assert((!operands.empty() || op->getNumOperands() == 0) &&
|
||||
"converting op before ops defining its operands");
|
||||
|
||||
SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
|
||||
unsigned numSuccessorOperands = 0;
|
||||
for (unsigned i = 0, e = op->getNumSuccessors(); i < e; ++i)
|
||||
numSuccessorOperands += op->getNumSuccessorOperands(i);
|
||||
unsigned seen = 0;
|
||||
unsigned firstSuccessorOperand = op->getNumOperands() - numSuccessorOperands;
|
||||
for (unsigned i = 0, e = op->getNumSuccessors(); i < e; ++i) {
|
||||
Block *successor = mapping.lookupOrNull(op->getSuccessor(i));
|
||||
assert(successor && "block was not remapped");
|
||||
destinations.push_back(successor);
|
||||
unsigned n = op->getNumSuccessorOperands(i);
|
||||
operandsPerDestination.push_back(
|
||||
llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
|
||||
seen += n;
|
||||
}
|
||||
converter->rewriteTerminator(
|
||||
op,
|
||||
llvm::makeArrayRef(operands.data(),
|
||||
operands.data() + firstSuccessorOperand),
|
||||
destinations, operandsPerDestination, builder);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
impl::FunctionConversion::convertOp(DialectOpConversion *converter,
|
||||
Operation *op, FuncBuilder &builder) {
|
||||
auto operands = lookupValues(op->getOperands());
|
||||
assert((!operands.empty() || op->getNumOperands() == 0) &&
|
||||
"converting op before ops defining its operands");
|
||||
|
||||
auto results = converter->rewrite(op, operands, builder);
|
||||
if (results.size() != op->getNumResults())
|
||||
return op->emitError("rewriting produced a different number of results");
|
||||
|
||||
for (unsigned i = 0, e = results.size(); i < e; ++i)
|
||||
mapping.map(op->getResult(i), results[i]);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
|
||||
impl::FunctionConversion::convertBlock(Block *block,
|
||||
llvm::DenseSet<Block *> &visitedBlocks) {
|
||||
// First, add the current block to the list of visited blocks.
|
||||
visitedBlocks.insert(block);
|
||||
// Setup the builder to the insert to the converted block.
|
||||
builder.setInsertionPointToStart(mapping.lookupOrNull(block));
|
||||
rewriter.setInsertionPointToStart(rewriter.mapping.lookupOrNull(block));
|
||||
|
||||
// Iterate over ops and convert them.
|
||||
for (Operation &op : *block) {
|
||||
// Find the first matching conversion and apply it.
|
||||
bool converted = false;
|
||||
for (auto *conversion : conversions) {
|
||||
// Ignore patterns that are for the wrong root or are impossible to match.
|
||||
if (conversion->getRootKind() != op.getName() ||
|
||||
conversion->getBenefit().isImpossibleToMatch())
|
||||
continue;
|
||||
if (matcher.matchAndRewrite(&op, rewriter))
|
||||
continue;
|
||||
|
||||
if (!conversion->match(&op))
|
||||
continue;
|
||||
|
||||
if (op.getNumSuccessors() != 0) {
|
||||
if (failed(convertOpWithSuccessors(conversion, &op, builder)))
|
||||
return failure();
|
||||
} else if (failed(convertOp(conversion, &op, builder))) {
|
||||
return failure();
|
||||
}
|
||||
converted = true;
|
||||
break;
|
||||
}
|
||||
// If there is no conversion provided for the op, clone the op and convert
|
||||
// its regions, if any.
|
||||
if (!converted) {
|
||||
auto *newOp = builder.cloneWithoutRegions(op, mapping);
|
||||
for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
|
||||
auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
|
||||
newOp->getRegion(i).takeBody(*newRegion);
|
||||
}
|
||||
auto *newOp = rewriter.cloneWithoutRegions(op, rewriter.mapping);
|
||||
for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
|
||||
auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
|
||||
newOp->getRegion(i).takeBody(*newRegion);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -210,7 +177,7 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
|
|||
for (Block *succ : block->getSuccessors()) {
|
||||
if (visitedBlocks.count(succ) != 0)
|
||||
continue;
|
||||
if (failed(convertBlock(succ, builder, visitedBlocks)))
|
||||
if (failed(convertBlock(succ, visitedBlocks)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
|
@ -234,33 +201,31 @@ impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region,
|
|||
for (Block &block : *region) {
|
||||
auto *newBlock = new Block;
|
||||
newRegion->push_back(newBlock);
|
||||
mapping.map(&block, newBlock);
|
||||
rewriter.mapping.map(&block, newBlock);
|
||||
for (auto *arg : block.getArguments()) {
|
||||
auto convertedType = dialectConversion->convertType(arg->getType());
|
||||
if (!convertedType)
|
||||
return emitError("could not convert block argument type");
|
||||
newBlock->addArgument(convertedType);
|
||||
mapping.map(arg, *newBlock->args_rbegin());
|
||||
rewriter.mapping.map(arg, *newBlock->args_rbegin());
|
||||
}
|
||||
}
|
||||
|
||||
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
||||
// before uses in dominated blocks.
|
||||
llvm::DenseSet<Block *> visitedBlocks;
|
||||
FuncBuilder builder(&newRegion->front());
|
||||
if (failed(convertBlock(®ion->front(), builder, visitedBlocks)))
|
||||
if (failed(convertBlock(®ion->front(), visitedBlocks)))
|
||||
return nullptr;
|
||||
|
||||
// If some blocks are not reachable through successor chains, they should have
|
||||
// been removed by the DCE before this.
|
||||
|
||||
if (visitedBlocks.size() != std::distance(region->begin(), region->end()))
|
||||
return emitError("unreachable blocks were not converted");
|
||||
return newRegion;
|
||||
}
|
||||
|
||||
Function *impl::FunctionConversion::convertFunction(Function *f) {
|
||||
assert(f && "expected function");
|
||||
Function *impl::FunctionConversion::convertFunction() {
|
||||
Function *f = rewriter.getFunction();
|
||||
MLIRContext *context = f->getContext();
|
||||
auto emitError = [context](llvm::Twine f) -> Function * {
|
||||
context->emitError(UnknownLoc::get(context), f.str());
|
||||
|
@ -278,8 +243,8 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
|
|||
f->getLoc(), f->getName().strref(), newFunctionType.cast<FunctionType>(),
|
||||
f->getAttrs(), newFunctionArgAttrs);
|
||||
|
||||
// Return early if the function has no blocks.
|
||||
if (f->getBlocks().empty())
|
||||
// Return early if the function is external.
|
||||
if (f->isExternal())
|
||||
return newFunction.release();
|
||||
|
||||
auto newBody = convertRegion(context, &f->getBody(), f);
|
||||
|
@ -290,54 +255,6 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
|
|||
return newFunction.release();
|
||||
}
|
||||
|
||||
LogicalResult impl::FunctionConversion::convert(DialectConversion *conversion,
|
||||
Module *module) {
|
||||
return impl::FunctionConversion(conversion).run(module);
|
||||
}
|
||||
|
||||
LogicalResult impl::FunctionConversion::run(Module *module) {
|
||||
if (!module)
|
||||
return failure();
|
||||
|
||||
MLIRContext *context = module->getContext();
|
||||
conversions = dialectConversion->initConverters(context);
|
||||
|
||||
// Convert the functions but don't add them to the module yet to avoid
|
||||
// converted functions to be converted again.
|
||||
SmallVector<Function *, 0> originalFuncs, convertedFuncs;
|
||||
DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
|
||||
originalFuncs.reserve(module->getFunctions().size());
|
||||
for (auto &func : *module)
|
||||
originalFuncs.push_back(&func);
|
||||
convertedFuncs.reserve(module->getFunctions().size());
|
||||
for (auto *func : originalFuncs) {
|
||||
Function *converted = convertFunction(func);
|
||||
if (!converted)
|
||||
return failure();
|
||||
|
||||
auto origFuncAttr = FunctionAttr::get(func);
|
||||
auto convertedFuncAttr = FunctionAttr::get(converted);
|
||||
convertedFuncs.push_back(converted);
|
||||
functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
|
||||
}
|
||||
|
||||
// Remap function attributes in the converted functions (they are not yet in
|
||||
// the module). Original functions will disappear anyway so there is no
|
||||
// need to remap attributes in them.
|
||||
for (const auto &funcPair : functionAttrRemapping) {
|
||||
remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
|
||||
}
|
||||
|
||||
// Remove original functions from the module, then insert converted
|
||||
// functions. The order is important to avoid name collisions.
|
||||
for (auto &func : originalFuncs)
|
||||
func->erase();
|
||||
for (auto *func : convertedFuncs)
|
||||
module->getFunctions().push_back(func);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Create a function type with arguments and results converted, and argument
|
||||
// attributes passed through.
|
||||
FunctionType DialectConversion::convertFunctionSignatureType(
|
||||
|
@ -363,6 +280,55 @@ FunctionType DialectConversion::convertFunctionSignatureType(
|
|||
return FunctionType::get(arguments, results, type.getContext());
|
||||
}
|
||||
|
||||
LogicalResult DialectConversion::convert(Module *m) {
|
||||
return impl::FunctionConversion::convert(this, m);
|
||||
// Converts the module as follows.
|
||||
// 1. Call `convertFunction` on each function of the module and collect the
|
||||
// mapping between old and new functions.
|
||||
// 2. Remap all function attributes in the new functions to point to the new
|
||||
// functions instead of the old ones.
|
||||
// 3. Replace old functions with the new in the module.
|
||||
LogicalResult DialectConversion::convert(Module *module) {
|
||||
if (!module)
|
||||
return failure();
|
||||
|
||||
// Grab the conversion patterns from the converter and create the pattern
|
||||
// matcher.
|
||||
MLIRContext *context = module->getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
initConverters(patterns, context);
|
||||
RewritePatternMatcher matcher(std::move(patterns));
|
||||
|
||||
// Convert the functions but don't add them to the module yet to avoid
|
||||
// converted functions to be converted again.
|
||||
SmallVector<Function *, 0> originalFuncs, convertedFuncs;
|
||||
DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
|
||||
originalFuncs.reserve(module->getFunctions().size());
|
||||
for (auto &func : *module)
|
||||
originalFuncs.push_back(&func);
|
||||
convertedFuncs.reserve(module->getFunctions().size());
|
||||
for (auto *func : originalFuncs) {
|
||||
impl::FunctionConversion converter(this, func, matcher);
|
||||
Function *converted = converter.convertFunction();
|
||||
if (!converted)
|
||||
return failure();
|
||||
|
||||
auto origFuncAttr = FunctionAttr::get(func);
|
||||
auto convertedFuncAttr = FunctionAttr::get(converted);
|
||||
convertedFuncs.push_back(converted);
|
||||
functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
|
||||
}
|
||||
|
||||
// Remap function attributes in the converted functions (they are not yet in
|
||||
// the module). Original functions will disappear anyway so there is no
|
||||
// need to remap attributes in them.
|
||||
for (const auto &funcPair : functionAttrRemapping)
|
||||
remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
|
||||
|
||||
// Remove original functions from the module, then insert converted
|
||||
// functions. The order is important to avoid name collisions.
|
||||
for (auto &func : originalFuncs)
|
||||
func->erase();
|
||||
for (auto *func : convertedFuncs)
|
||||
module->getFunctions().push_back(func);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
|
|||
public:
|
||||
explicit GreedyPatternRewriteDriver(Function &fn,
|
||||
OwningRewritePatternList &&patterns)
|
||||
: PatternRewriter(&fn), matcher(std::move(patterns), *this) {
|
||||
: PatternRewriter(&fn), matcher(std::move(patterns)) {
|
||||
worklist.reserve(64);
|
||||
}
|
||||
|
||||
|
@ -202,7 +202,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
|||
// Try to match one of the canonicalization patterns. The rewriter is
|
||||
// automatically notified of any necessary changes, so there is nothing
|
||||
// else to do here.
|
||||
changed |= matcher.matchAndRewrite(op);
|
||||
changed |= matcher.matchAndRewrite(op, *this);
|
||||
}
|
||||
} while (changed && ++i < maxIterations);
|
||||
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
|
||||
|
|
Loading…
Reference in New Issue