From f27f1e8c27b1d7cf624877e798999244a72adb41 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 3 Apr 2020 19:53:13 +0200 Subject: [PATCH] [mlir] DialectConversion: support block creation in ConversionPatternRewriter PatternRewriter and derived classes provide a set of virtual methods to manipulate blocks, which ConversionPatternRewriter overrides to keep track of the manipulations and undo them in case the conversion fails. However, one can currently create a block only by splitting another block into two. This not only makes the API inconsistent (`splitBlock` is allowed in conversion patterns, but `createBlock` is not), but it also make it impossible for one to create blocks with argument lists different from those of already existing blocks since in-place block updates are not supported either. Such functionality precludes dialect conversion infrastructure from being used more extensively on region-containing ops, for example, for value-returning "if" operations. At the same time, ConversionPatternRewriter already allows one to undo block creation as block creation is one of the primitive operations in already supported region inlining. Support block creation in conversion patterns by hooking `createBlock` on the block action undo mechanism. This requires to make `Builder::createBlock` virtual, similarly to Op insertion. This is a minimal change to the Builder infrastructure that will later help support additional use cases such as block signature changes. `createBlock` now additionally takes the types of the block arguments that are added immediately so as to avoid in-place argument list manipulation that would be illegal in conversion patterns. --- mlir/include/mlir/IR/Builders.h | 14 +++--- .../mlir/Transforms/DialectConversion.h | 4 ++ mlir/lib/IR/Builders.cpp | 18 ++++--- mlir/lib/Transforms/DialectConversion.cpp | 16 ++++++ mlir/test/Transforms/test-legalizer.mlir | 27 ++++++++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 50 ++++++++++++++++--- 6 files changed, 109 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1c6b16f22989..75f49e86d10a 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -298,13 +298,15 @@ public: /// Insert the given operation at the current insertion point and return it. virtual Operation *insert(Operation *op); - /// Add new block and set the insertion point to the end of it. The block is - /// inserted at the provided insertion point of 'parent'. - Block *createBlock(Region *parent, Region::iterator insertPt = {}); + /// Add new block with 'argTypes' arguments and set the insertion point to the + /// end of it. The block is inserted at the provided insertion point of + /// 'parent'. + virtual Block *createBlock(Region *parent, Region::iterator insertPt = {}, + TypeRange argTypes = llvm::None); - /// Add new block and set the insertion point to the end of it. The block is - /// placed before 'insertBefore'. - Block *createBlock(Block *insertBefore); + /// Add new block with 'argTypes' arguments and set the insertion point to the + /// end of it. The block is placed before 'insertBefore'. + Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None); /// Returns the current block of the builder. Block *getBlock() const { return block; } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 776007347c5e..9ab3a715e0ab 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -344,6 +344,10 @@ public: /// otherwise an assert will be issued. void eraseOp(Operation *op) override; + /// PatternRewriter hook for creating a new block with the given arguments. + Block *createBlock(Region *parent, Region::iterator insertPt = {}, + TypeRange argTypes = llvm::None) override; + /// PatternRewriter hook for splitting a block into two parts. Block *splitBlock(Block *block, Block::iterator before) override; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 23536651f974..c8d5ea6b6ca9 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -339,24 +339,28 @@ Operation *OpBuilder::insert(Operation *op) { return op; } -/// Add new block and set the insertion point to the end of it. The block is -/// inserted at the provided insertion point of 'parent'. -Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) { +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is inserted at the provided insertion point of +/// 'parent'. +Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt, + TypeRange argTypes) { assert(parent && "expected valid parent region"); if (insertPt == Region::iterator()) insertPt = parent->end(); Block *b = new Block(); + b->addArguments(argTypes); parent->getBlocks().insert(insertPt, b); setInsertionPointToEnd(b); return b; } -/// Add new block and set the insertion point to the end of it. The block is -/// placed before 'insertBefore'. -Block *OpBuilder::createBlock(Block *insertBefore) { +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is placed before 'insertBefore'. +Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) { assert(insertBefore && "expected valid insertion block"); - return createBlock(insertBefore->getParent(), Region::iterator(insertBefore)); + return createBlock(insertBefore->getParent(), Region::iterator(insertBefore), + argTypes); } /// Create an operation given the fields represented as an OperationState. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 19304b3fb73f..725f5f4bb16e 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -585,6 +585,9 @@ struct ConversionPatternRewriterImpl { /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ValueRange newValues); + /// Notifies that a block was created. + void notifyCreatedBlock(Block *block); + /// Notifies that a block was split. void notifySplitBlock(Block *block, Block *continuation); @@ -804,6 +807,10 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op, markNestedOpsIgnored(op); } +void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) { + blockActions.push_back(BlockAction::getCreate(block)); +} + void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, Block *continuation) { blockActions.push_back(BlockAction::getSplit(continuation, block)); @@ -910,6 +917,15 @@ Value ConversionPatternRewriter::getRemappedValue(Value key) { return impl->mapping.lookupOrDefault(key); } +/// PatternRewriter hook for creating a new block with the given arguments. +Block *ConversionPatternRewriter::createBlock(Region *parent, + Region::iterator insertPtr, + TypeRange argTypes) { + Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes); + impl->notifyCreatedBlock(block); + return block; +} + /// PatternRewriter hook for splitting a block into two parts. Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index bd73cf30639a..3305e017d5b3 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -130,6 +130,19 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) { return %0 : i32 } +// CHECK-LABEL: @create_block +func @create_block() { + "test.container"() ({ + // Check that we created a block with arguments. + // CHECK-NOT: test.create_block + // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + // CHECK: test.finish + "test.create_block"() : () -> () + "test.finish"() : () -> () + }) : () -> () + return +} + // ----- func @fail_to_convert_illegal_op() -> i32 { @@ -163,3 +176,17 @@ func @fail_to_convert_region() { }) : () -> () return } + +// ----- + +// CHECK-LABEL: @create_illegal_block +func @create_illegal_block() { + "test.container"() ({ + // Check that we can undo block creation, i.e. that the block was removed. + // CHECK: test.create_illegal_block + // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + "test.create_illegal_block"() : () -> () + "test.finish"() : () -> () + }) : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 0b73f09c1943..23d650e15479 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -183,6 +183,41 @@ struct TestRegionRewriteUndo : public RewritePattern { return success(); } }; +/// A simple pattern that creates a block at the end of the parent region of the +/// matched operation. +struct TestCreateBlock : public RewritePattern { + TestCreateBlock(MLIRContext *ctx) + : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + Region ®ion = *op->getParentRegion(); + Type i32Type = rewriter.getIntegerType(32); + rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); + rewriter.create(op->getLoc()); + rewriter.replaceOp(op, {}); + return success(); + } +}; + +/// A simple pattern that creates a block containing an invalid operaiton in +/// order to trigger the block creation undo mechanism. +struct TestCreateIllegalBlock : public RewritePattern { + TestCreateIllegalBlock(MLIRContext *ctx) + : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + Region ®ion = *op->getParentRegion(); + Type i32Type = rewriter.getIntegerType(32); + rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); + // Create an illegal op to ensure the conversion fails. + rewriter.create(op->getLoc(), i32Type); + rewriter.create(op->getLoc()); + rewriter.replaceOp(op, {}); + return success(); + } +}; //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing @@ -373,12 +408,12 @@ struct TestLegalizePatternDriver TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - patterns - .insert(&getContext()); + patterns.insert< + TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, + TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType, + TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, + TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, + TestNonRootReplacement>(&getContext()); patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); @@ -388,7 +423,8 @@ struct TestLegalizePatternDriver // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); target .addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) {