[mlir] DialectConversion: support block creation in ConversionPatternRewriter

PatternRewriter and derived classes provide a set of virtual methods to
manipulate blocks, which ConversionPatternRewriter overrides to keep track of
the manipulations and undo them in case the conversion fails. However, one can
currently create a block only by splitting another block into two. This not
only makes the API inconsistent (`splitBlock` is allowed in conversion
patterns, but `createBlock` is not), but it also make it impossible for one to
create blocks with argument lists different from those of already existing
blocks since in-place block updates are not supported either. Such
functionality precludes dialect conversion infrastructure from being used more
extensively on region-containing ops, for example, for value-returning "if"
operations. At the same time, ConversionPatternRewriter already allows one to
undo block creation as block creation is one of the primitive operations in
already supported region inlining.

Support block creation in conversion patterns by hooking `createBlock` on the
block action undo mechanism. This requires to make `Builder::createBlock`
virtual, similarly to Op insertion. This is a minimal change to the Builder
infrastructure that will later help support additional use cases such as block
signature changes. `createBlock` now additionally takes the types of the block
arguments that are added immediately so as to avoid in-place argument list
manipulation that would be illegal in conversion patterns.
This commit is contained in:
Alex Zinenko 2020-04-03 19:53:13 +02:00
parent b600809688
commit f27f1e8c27
6 changed files with 109 additions and 20 deletions

View File

@ -298,13 +298,15 @@ public:
/// Insert the given operation at the current insertion point and return it.
virtual Operation *insert(Operation *op);
/// Add new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *createBlock(Region *parent, Region::iterator insertPt = {});
/// Add new block with 'argTypes' arguments and set the insertion point to the
/// end of it. The block is inserted at the provided insertion point of
/// 'parent'.
virtual Block *createBlock(Region *parent, Region::iterator insertPt = {},
TypeRange argTypes = llvm::None);
/// Add new block and set the insertion point to the end of it. The block is
/// placed before 'insertBefore'.
Block *createBlock(Block *insertBefore);
/// Add new block with 'argTypes' arguments and set the insertion point to the
/// end of it. The block is placed before 'insertBefore'.
Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
/// Returns the current block of the builder.
Block *getBlock() const { return block; }

View File

@ -344,6 +344,10 @@ public:
/// otherwise an assert will be issued.
void eraseOp(Operation *op) override;
/// PatternRewriter hook for creating a new block with the given arguments.
Block *createBlock(Region *parent, Region::iterator insertPt = {},
TypeRange argTypes = llvm::None) override;
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;

View File

@ -339,24 +339,28 @@ Operation *OpBuilder::insert(Operation *op) {
return op;
}
/// Add new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
/// Add new block with 'argTypes' arguments and set the insertion point to the
/// end of it. The block is inserted at the provided insertion point of
/// 'parent'.
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
TypeRange argTypes) {
assert(parent && "expected valid parent region");
if (insertPt == Region::iterator())
insertPt = parent->end();
Block *b = new Block();
b->addArguments(argTypes);
parent->getBlocks().insert(insertPt, b);
setInsertionPointToEnd(b);
return b;
}
/// Add new block and set the insertion point to the end of it. The block is
/// placed before 'insertBefore'.
Block *OpBuilder::createBlock(Block *insertBefore) {
/// Add new block with 'argTypes' arguments and set the insertion point to the
/// end of it. The block is placed before 'insertBefore'.
Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
assert(insertBefore && "expected valid insertion block");
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
argTypes);
}
/// Create an operation given the fields represented as an OperationState.

View File

@ -585,6 +585,9 @@ struct ConversionPatternRewriterImpl {
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues);
/// Notifies that a block was created.
void notifyCreatedBlock(Block *block);
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
@ -804,6 +807,10 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op,
markNestedOpsIgnored(op);
}
void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
blockActions.push_back(BlockAction::getCreate(block));
}
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
Block *continuation) {
blockActions.push_back(BlockAction::getSplit(continuation, block));
@ -910,6 +917,15 @@ Value ConversionPatternRewriter::getRemappedValue(Value key) {
return impl->mapping.lookupOrDefault(key);
}
/// PatternRewriter hook for creating a new block with the given arguments.
Block *ConversionPatternRewriter::createBlock(Region *parent,
Region::iterator insertPtr,
TypeRange argTypes) {
Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes);
impl->notifyCreatedBlock(block);
return block;
}
/// PatternRewriter hook for splitting a block into two parts.
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {

View File

@ -130,6 +130,19 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) {
return %0 : i32
}
// CHECK-LABEL: @create_block
func @create_block() {
"test.container"() ({
// Check that we created a block with arguments.
// CHECK-NOT: test.create_block
// CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
// CHECK: test.finish
"test.create_block"() : () -> ()
"test.finish"() : () -> ()
}) : () -> ()
return
}
// -----
func @fail_to_convert_illegal_op() -> i32 {
@ -163,3 +176,17 @@ func @fail_to_convert_region() {
}) : () -> ()
return
}
// -----
// CHECK-LABEL: @create_illegal_block
func @create_illegal_block() {
"test.container"() ({
// Check that we can undo block creation, i.e. that the block was removed.
// CHECK: test.create_illegal_block
// CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
"test.create_illegal_block"() : () -> ()
"test.finish"() : () -> ()
}) : () -> ()
return
}

View File

@ -183,6 +183,41 @@ struct TestRegionRewriteUndo : public RewritePattern {
return success();
}
};
/// A simple pattern that creates a block at the end of the parent region of the
/// matched operation.
struct TestCreateBlock : public RewritePattern {
TestCreateBlock(MLIRContext *ctx)
: RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
Region &region = *op->getParentRegion();
Type i32Type = rewriter.getIntegerType(32);
rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
rewriter.create<TerminatorOp>(op->getLoc());
rewriter.replaceOp(op, {});
return success();
}
};
/// A simple pattern that creates a block containing an invalid operaiton in
/// order to trigger the block creation undo mechanism.
struct TestCreateIllegalBlock : public RewritePattern {
TestCreateIllegalBlock(MLIRContext *ctx)
: RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
Region &region = *op->getParentRegion();
Type i32Type = rewriter.getIntegerType(32);
rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
// Create an illegal op to ensure the conversion fails.
rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
rewriter.create<TerminatorOp>(op->getLoc());
rewriter.replaceOp(op, {});
return success();
}
};
//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing
@ -373,12 +408,12 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
patterns
.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement>(&getContext());
patterns.insert<
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
@ -388,7 +423,8 @@ struct TestLegalizePatternDriver
// Define the conversion target used for the test.
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
TerminatorOp>();
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {