forked from OSchip/llvm-project
[mlir] DialectConversion: support block creation in ConversionPatternRewriter
PatternRewriter and derived classes provide a set of virtual methods to manipulate blocks, which ConversionPatternRewriter overrides to keep track of the manipulations and undo them in case the conversion fails. However, one can currently create a block only by splitting another block into two. This not only makes the API inconsistent (`splitBlock` is allowed in conversion patterns, but `createBlock` is not), but it also make it impossible for one to create blocks with argument lists different from those of already existing blocks since in-place block updates are not supported either. Such functionality precludes dialect conversion infrastructure from being used more extensively on region-containing ops, for example, for value-returning "if" operations. At the same time, ConversionPatternRewriter already allows one to undo block creation as block creation is one of the primitive operations in already supported region inlining. Support block creation in conversion patterns by hooking `createBlock` on the block action undo mechanism. This requires to make `Builder::createBlock` virtual, similarly to Op insertion. This is a minimal change to the Builder infrastructure that will later help support additional use cases such as block signature changes. `createBlock` now additionally takes the types of the block arguments that are added immediately so as to avoid in-place argument list manipulation that would be illegal in conversion patterns.
This commit is contained in:
parent
b600809688
commit
f27f1e8c27
|
@ -298,13 +298,15 @@ public:
|
|||
/// Insert the given operation at the current insertion point and return it.
|
||||
virtual Operation *insert(Operation *op);
|
||||
|
||||
/// Add new block and set the insertion point to the end of it. The block is
|
||||
/// inserted at the provided insertion point of 'parent'.
|
||||
Block *createBlock(Region *parent, Region::iterator insertPt = {});
|
||||
/// Add new block with 'argTypes' arguments and set the insertion point to the
|
||||
/// end of it. The block is inserted at the provided insertion point of
|
||||
/// 'parent'.
|
||||
virtual Block *createBlock(Region *parent, Region::iterator insertPt = {},
|
||||
TypeRange argTypes = llvm::None);
|
||||
|
||||
/// Add new block and set the insertion point to the end of it. The block is
|
||||
/// placed before 'insertBefore'.
|
||||
Block *createBlock(Block *insertBefore);
|
||||
/// Add new block with 'argTypes' arguments and set the insertion point to the
|
||||
/// end of it. The block is placed before 'insertBefore'.
|
||||
Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
|
||||
|
||||
/// Returns the current block of the builder.
|
||||
Block *getBlock() const { return block; }
|
||||
|
|
|
@ -344,6 +344,10 @@ public:
|
|||
/// otherwise an assert will be issued.
|
||||
void eraseOp(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for creating a new block with the given arguments.
|
||||
Block *createBlock(Region *parent, Region::iterator insertPt = {},
|
||||
TypeRange argTypes = llvm::None) override;
|
||||
|
||||
/// PatternRewriter hook for splitting a block into two parts.
|
||||
Block *splitBlock(Block *block, Block::iterator before) override;
|
||||
|
||||
|
|
|
@ -339,24 +339,28 @@ Operation *OpBuilder::insert(Operation *op) {
|
|||
return op;
|
||||
}
|
||||
|
||||
/// Add new block and set the insertion point to the end of it. The block is
|
||||
/// inserted at the provided insertion point of 'parent'.
|
||||
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
|
||||
/// Add new block with 'argTypes' arguments and set the insertion point to the
|
||||
/// end of it. The block is inserted at the provided insertion point of
|
||||
/// 'parent'.
|
||||
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
|
||||
TypeRange argTypes) {
|
||||
assert(parent && "expected valid parent region");
|
||||
if (insertPt == Region::iterator())
|
||||
insertPt = parent->end();
|
||||
|
||||
Block *b = new Block();
|
||||
b->addArguments(argTypes);
|
||||
parent->getBlocks().insert(insertPt, b);
|
||||
setInsertionPointToEnd(b);
|
||||
return b;
|
||||
}
|
||||
|
||||
/// Add new block and set the insertion point to the end of it. The block is
|
||||
/// placed before 'insertBefore'.
|
||||
Block *OpBuilder::createBlock(Block *insertBefore) {
|
||||
/// Add new block with 'argTypes' arguments and set the insertion point to the
|
||||
/// end of it. The block is placed before 'insertBefore'.
|
||||
Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
|
||||
assert(insertBefore && "expected valid insertion block");
|
||||
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
|
||||
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
|
||||
argTypes);
|
||||
}
|
||||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
|
|
|
@ -585,6 +585,9 @@ struct ConversionPatternRewriterImpl {
|
|||
/// PatternRewriter hook for replacing the results of an operation.
|
||||
void replaceOp(Operation *op, ValueRange newValues);
|
||||
|
||||
/// Notifies that a block was created.
|
||||
void notifyCreatedBlock(Block *block);
|
||||
|
||||
/// Notifies that a block was split.
|
||||
void notifySplitBlock(Block *block, Block *continuation);
|
||||
|
||||
|
@ -804,6 +807,10 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op,
|
|||
markNestedOpsIgnored(op);
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
|
||||
blockActions.push_back(BlockAction::getCreate(block));
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
|
||||
Block *continuation) {
|
||||
blockActions.push_back(BlockAction::getSplit(continuation, block));
|
||||
|
@ -910,6 +917,15 @@ Value ConversionPatternRewriter::getRemappedValue(Value key) {
|
|||
return impl->mapping.lookupOrDefault(key);
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for creating a new block with the given arguments.
|
||||
Block *ConversionPatternRewriter::createBlock(Region *parent,
|
||||
Region::iterator insertPtr,
|
||||
TypeRange argTypes) {
|
||||
Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes);
|
||||
impl->notifyCreatedBlock(block);
|
||||
return block;
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for splitting a block into two parts.
|
||||
Block *ConversionPatternRewriter::splitBlock(Block *block,
|
||||
Block::iterator before) {
|
||||
|
|
|
@ -130,6 +130,19 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) {
|
|||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @create_block
|
||||
func @create_block() {
|
||||
"test.container"() ({
|
||||
// Check that we created a block with arguments.
|
||||
// CHECK-NOT: test.create_block
|
||||
// CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
|
||||
// CHECK: test.finish
|
||||
"test.create_block"() : () -> ()
|
||||
"test.finish"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @fail_to_convert_illegal_op() -> i32 {
|
||||
|
@ -163,3 +176,17 @@ func @fail_to_convert_region() {
|
|||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @create_illegal_block
|
||||
func @create_illegal_block() {
|
||||
"test.container"() ({
|
||||
// Check that we can undo block creation, i.e. that the block was removed.
|
||||
// CHECK: test.create_illegal_block
|
||||
// CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
|
||||
"test.create_illegal_block"() : () -> ()
|
||||
"test.finish"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -183,6 +183,41 @@ struct TestRegionRewriteUndo : public RewritePattern {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
/// A simple pattern that creates a block at the end of the parent region of the
|
||||
/// matched operation.
|
||||
struct TestCreateBlock : public RewritePattern {
|
||||
TestCreateBlock(MLIRContext *ctx)
|
||||
: RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Region ®ion = *op->getParentRegion();
|
||||
Type i32Type = rewriter.getIntegerType(32);
|
||||
rewriter.createBlock(®ion, region.end(), {i32Type, i32Type});
|
||||
rewriter.create<TerminatorOp>(op->getLoc());
|
||||
rewriter.replaceOp(op, {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// A simple pattern that creates a block containing an invalid operaiton in
|
||||
/// order to trigger the block creation undo mechanism.
|
||||
struct TestCreateIllegalBlock : public RewritePattern {
|
||||
TestCreateIllegalBlock(MLIRContext *ctx)
|
||||
: RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Region ®ion = *op->getParentRegion();
|
||||
Type i32Type = rewriter.getIntegerType(32);
|
||||
rewriter.createBlock(®ion, region.end(), {i32Type, i32Type});
|
||||
// Create an illegal op to ensure the conversion fails.
|
||||
rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
|
||||
rewriter.create<TerminatorOp>(op->getLoc());
|
||||
rewriter.replaceOp(op, {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type-Conversion Rewrite Testing
|
||||
|
@ -373,12 +408,12 @@ struct TestLegalizePatternDriver
|
|||
TestTypeConverter converter;
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
patterns
|
||||
.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestPassthroughInvalidOp, TestSplitReturnType,
|
||||
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
|
||||
TestNonRootReplacement>(&getContext());
|
||||
patterns.insert<
|
||||
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
|
||||
TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
|
||||
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
|
||||
TestNonRootReplacement>(&getContext());
|
||||
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||
converter);
|
||||
|
@ -388,7 +423,8 @@ struct TestLegalizePatternDriver
|
|||
// Define the conversion target used for the test.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
|
||||
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
|
||||
TerminatorOp>();
|
||||
target
|
||||
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
|
||||
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
|
||||
|
|
Loading…
Reference in New Issue