Add hook for dialect specializing processing blocks post inlining calls

This allows for dialects to do different post-processing depending on operations with the inliner (my use case requires different attribute propagation rules depending on call op). This hook runs before the regular processInlinedBlocks method.

Differential Revision: https://reviews.llvm.org/D104399
This commit is contained in:
Jacques Pienaar 2021-06-16 12:53:21 -07:00
parent e5813a683a
commit 0e760a0870
4 changed files with 74 additions and 26 deletions

View File

@ -140,6 +140,11 @@ public:
Location conversionLoc) const {
return nullptr;
}
/// Process a set of blocks that have been inlined for a call. This callback
/// is invoked before inlined terminator operations have been processed.
virtual void processInlinedCallBlocks(
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {}
};
/// This interface provides the hooks into the inlining interface.
@ -178,6 +183,8 @@ public:
virtual void handleTerminator(Operation *op, Block *newDest) const;
virtual void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const;
virtual void processInlinedCallBlocks(
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
};
//===----------------------------------------------------------------------===//
@ -209,8 +216,7 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
ValueRange inlinedOperands,
Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace,
Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);

View File

@ -106,6 +106,13 @@ void InlinerInterface::handleTerminator(Operation *op,
handler->handleTerminator(op, valuesToRepl);
}
void InlinerInterface::processInlinedCallBlocks(
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
auto *handler = getInterfaceFor(call);
assert(handler && "expected valid dialect handler");
handler->processInlinedCallBlocks(call, inlinedBlocks);
}
/// Utility to check that all of the operations within 'src' can be inlined.
static bool isLegalToInline(InlinerInterface &interface, Region *src,
Region *insertRegion, bool shouldCloneInlinedRegion,
@ -137,13 +144,12 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
// Inline Methods
//===----------------------------------------------------------------------===//
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
BlockAndValueMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src,
Operation *inlinePoint, BlockAndValueMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
Operation *call) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
if (src->empty())
@ -198,6 +204,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
remapInlinedOperands(newBlocks, mapper);
// Process the newly inlined blocks.
if (call)
interface.processInlinedCallBlocks(call, newBlocks);
interface.processInlinedBlocks(newBlocks);
// Handle the case where only a single block was inlined.
@ -232,15 +240,11 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
return success();
}
/// This function is an overload of the above 'inlineRegion' that allows for
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src,
Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, Operation *call) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
@ -261,9 +265,33 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
}
// Call into the main region inliner function.
return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
resultsToReplace.getTypes(), inlineLoc,
shouldCloneInlinedRegion);
return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
resultsToReplace.getTypes(), inlineLoc,
shouldCloneInlinedRegion, call);
}
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
BlockAndValueMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
regionResultTypes, inlineLoc,
shouldCloneInlinedRegion,
/*call=*/nullptr);
}
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands,
resultsToReplace, inlineLoc, shouldCloneInlinedRegion,
/*call=*/nullptr);
}
/// Utility function used to generate a cast operation from the given interface,
@ -371,9 +399,9 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState();
// Attempt to inline the call.
if (failed(inlineRegion(interface, src, call, mapper, callResults,
callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion)))
if (failed(inlineRegionImpl(interface, src, call, mapper, callResults,
callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion, call)))
return cleanupState();
return success();
}

View File

@ -140,9 +140,9 @@ func @convert_callee_fn_multiblock() -> i32 {
// CHECK-LABEL: func @inline_convert_result_multiblock
func @inline_convert_result_multiblock() -> i16 {
// CHECK: br ^bb1
// CHECK: br ^bb1 {inlined_conversion}
// CHECK: ^bb1:
// CHECK: %[[C:.+]] = constant 0 : i32
// CHECK: %[[C:.+]] = constant {inlined_conversion} 0 : i32
// CHECK: br ^bb2(%[[C]] : i32)
// CHECK: ^bb2(%[[BBARG:.+]]: i32):
// CHECK: %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16

View File

@ -171,6 +171,20 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return nullptr;
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
void processInlinedCallBlocks(
Operation *call,
iterator_range<Region::iterator> inlinedBlocks) const final {
if (!isa<ConversionCallOp>(call))
return;
// Set attributed on all ops in the inlined blocks.
for (Block &block : inlinedBlocks) {
block.walk([&](Operation *op) {
op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
});
}
}
};
struct TestReductionPatternInterface : public DialectReductionPatternInterface {