[mlir][NFC] Add inlineRegion overloads that take a block iterator insert position

This allows for inlining into an empty block or to the beginning of a block. NFC as the existing implementations now foward to this overload.

Differential Revision: https://reviews.llvm.org/D108572
This commit is contained in:
River Riddle 2021-08-23 19:49:38 +00:00
parent e8723abf43
commit da12d88b1c
2 changed files with 53 additions and 29 deletions

View File

@ -211,6 +211,13 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
TypeRange regionResultTypes, TypeRange regionResultTypes,
Optional<Location> inlineLoc = llvm::None, Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true); bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
BlockAndValueMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
/// This function is an overload of the above 'inlineRegion' that allows for /// This function is an overload of the above 'inlineRegion' that allows for
/// providing the set of operands ('inlinedOperands') that should be used /// providing the set of operands ('inlinedOperands') that should be used
@ -220,6 +227,12 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
ValueRange resultsToReplace, ValueRange resultsToReplace,
Optional<Location> inlineLoc = llvm::None, Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true); bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
/// This function inlines a given region, 'src', of a callable operation, /// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This /// 'callable', into the location defined by the given call operation. This

View File

@ -145,11 +145,11 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Operation *inlinePoint, BlockAndValueMapping &mapper, Block::iterator inlinePoint, BlockAndValueMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes, ValueRange resultsToReplace, TypeRange regionResultTypes,
Optional<Location> inlineLoc, bool shouldCloneInlinedRegion, Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
Operation *call) { Operation *call = nullptr) {
assert(resultsToReplace.size() == regionResultTypes.size()); assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block. // We expect the region to have at least one block.
if (src->empty()) if (src->empty())
@ -161,26 +161,18 @@ inlineRegionImpl(InlinerInterface &interface, Region *src,
[&](BlockArgument arg) { return !mapper.contains(arg); })) [&](BlockArgument arg) { return !mapper.contains(arg); }))
return failure(); return failure();
// The insertion point must be within a block.
Block *insertBlock = inlinePoint->getBlock();
if (!insertBlock)
return failure();
Region *insertRegion = insertBlock->getParent();
// Check that the operations within the source region are valid to inline. // Check that the operations within the source region are valid to inline.
Region *insertRegion = inlineBlock->getParent();
if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
mapper) || mapper) ||
!isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
mapper)) mapper))
return failure(); return failure();
// Split the insertion block.
Block *postInsertBlock =
insertBlock->splitBlock(++inlinePoint->getIterator());
// Check to see if the region is being cloned, or moved inline. In either // Check to see if the region is being cloned, or moved inline. In either
// case, move the new blocks after the 'insertBlock' to improve IR // case, move the new blocks after the 'insertBlock' to improve IR
// readability. // readability.
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
if (shouldCloneInlinedRegion) if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else else
@ -189,7 +181,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src,
src->end()); src->end());
// Get the range of newly inserted blocks. // Get the range of newly inserted blocks.
auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
postInsertBlock->getIterator()); postInsertBlock->getIterator());
Block *firstNewBlock = &*newBlocks.begin(); Block *firstNewBlock = &*newBlocks.begin();
@ -234,17 +226,17 @@ inlineRegionImpl(InlinerInterface &interface, Region *src,
} }
// Splice the instructions of the inlined entry block into the insert block. // Splice the instructions of the inlined entry block into the insert block.
insertBlock->getOperations().splice(insertBlock->end(), inlineBlock->getOperations().splice(inlineBlock->end(),
firstNewBlock->getOperations()); firstNewBlock->getOperations());
firstNewBlock->erase(); firstNewBlock->erase();
return success(); return success();
} }
static LogicalResult static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Operation *inlinePoint, ValueRange inlinedOperands, Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, Optional<Location> inlineLoc, ValueRange resultsToReplace, Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, Operation *call) { bool shouldCloneInlinedRegion, Operation *call = nullptr) {
// We expect the region to have at least one block. // We expect the region to have at least one block.
if (src->empty()) if (src->empty())
return failure(); return failure();
@ -265,9 +257,9 @@ inlineRegionImpl(InlinerInterface &interface, Region *src,
} }
// Call into the main region inliner function. // Call into the main region inliner function.
return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace, return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
resultsToReplace.getTypes(), inlineLoc, resultsToReplace, resultsToReplace.getTypes(),
shouldCloneInlinedRegion, call); inlineLoc, shouldCloneInlinedRegion, call);
} }
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
@ -277,10 +269,19 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
TypeRange regionResultTypes, TypeRange regionResultTypes,
Optional<Location> inlineLoc, Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) { bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace, return inlineRegion(interface, src, inlinePoint->getBlock(),
regionResultTypes, inlineLoc, ++inlinePoint->getIterator(), mapper, resultsToReplace,
shouldCloneInlinedRegion, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
/*call=*/nullptr); }
LogicalResult
mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, BlockAndValueMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, regionResultTypes, inlineLoc,
shouldCloneInlinedRegion);
} }
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
@ -289,9 +290,18 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
ValueRange resultsToReplace, ValueRange resultsToReplace,
Optional<Location> inlineLoc, Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) { bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands, return inlineRegion(interface, src, inlinePoint->getBlock(),
resultsToReplace, inlineLoc, shouldCloneInlinedRegion, ++inlinePoint->getIterator(), inlinedOperands,
/*call=*/nullptr); resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
}
LogicalResult
mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
inlinedOperands, resultsToReplace, inlineLoc,
shouldCloneInlinedRegion);
} }
/// Utility function used to generate a cast operation from the given interface, /// Utility function used to generate a cast operation from the given interface,
@ -399,7 +409,8 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState(); return cleanupState();
// Attempt to inline the call. // Attempt to inline the call.
if (failed(inlineRegionImpl(interface, src, call, mapper, callResults, if (failed(inlineRegionImpl(interface, src, call->getBlock(),
++call->getIterator(), mapper, callResults,
callableResultTypes, call.getLoc(), callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion, call))) shouldCloneInlinedRegion, call)))
return cleanupState(); return cleanupState();