diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index a86a6b9cb08e..8dcc1f5eb699 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -140,6 +140,11 @@ public: Location conversionLoc) const { return nullptr; } + + /// Process a set of blocks that have been inlined for a call. This callback + /// is invoked before inlined terminator operations have been processed. + virtual void processInlinedCallBlocks( + Operation *call, iterator_range inlinedBlocks) const {} }; /// This interface provides the hooks into the inlining interface. @@ -178,6 +183,8 @@ public: virtual void handleTerminator(Operation *op, Block *newDest) const; virtual void handleTerminator(Operation *op, ArrayRef valuesToRepl) const; + virtual void processInlinedCallBlocks( + Operation *call, iterator_range inlinedBlocks) const; }; //===----------------------------------------------------------------------===// @@ -209,8 +216,7 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src, /// providing the set of operands ('inlinedOperands') that should be used /// in-favor of the region arguments when inlining. LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, - ValueRange inlinedOperands, + Operation *inlinePoint, ValueRange inlinedOperands, ValueRange resultsToReplace, Optional inlineLoc = llvm::None, bool shouldCloneInlinedRegion = true); diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 7d18de076e4b..5b50d212fb07 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -106,6 +106,13 @@ void InlinerInterface::handleTerminator(Operation *op, handler->handleTerminator(op, valuesToRepl); } +void InlinerInterface::processInlinedCallBlocks( + Operation *call, iterator_range inlinedBlocks) const { + auto *handler = getInterfaceFor(call); + assert(handler && "expected valid dialect handler"); + handler->processInlinedCallBlocks(call, inlinedBlocks); +} + /// Utility to check that all of the operations within 'src' can be inlined. static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, @@ -137,13 +144,12 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, // Inline Methods //===----------------------------------------------------------------------===// -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, - BlockAndValueMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - Optional inlineLoc, - bool shouldCloneInlinedRegion) { +static LogicalResult +inlineRegionImpl(InlinerInterface &interface, Region *src, + Operation *inlinePoint, BlockAndValueMapping &mapper, + ValueRange resultsToReplace, TypeRange regionResultTypes, + Optional inlineLoc, bool shouldCloneInlinedRegion, + Operation *call) { assert(resultsToReplace.size() == regionResultTypes.size()); // We expect the region to have at least one block. if (src->empty()) @@ -198,6 +204,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, remapInlinedOperands(newBlocks, mapper); // Process the newly inlined blocks. + if (call) + interface.processInlinedCallBlocks(call, newBlocks); interface.processInlinedBlocks(newBlocks); // Handle the case where only a single block was inlined. @@ -232,15 +240,11 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, return success(); } -/// This function is an overload of the above 'inlineRegion' that allows for -/// providing the set of operands ('inlinedOperands') that should be used -/// in-favor of the region arguments when inlining. -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - Optional inlineLoc, - bool shouldCloneInlinedRegion) { +static LogicalResult +inlineRegionImpl(InlinerInterface &interface, Region *src, + Operation *inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, Optional inlineLoc, + bool shouldCloneInlinedRegion, Operation *call) { // We expect the region to have at least one block. if (src->empty()) return failure(); @@ -261,9 +265,33 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, } // Call into the main region inliner function. - return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, - resultsToReplace.getTypes(), inlineLoc, - shouldCloneInlinedRegion); + return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace, + resultsToReplace.getTypes(), inlineLoc, + shouldCloneInlinedRegion, call); +} + +LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, + Operation *inlinePoint, + BlockAndValueMapping &mapper, + ValueRange resultsToReplace, + TypeRange regionResultTypes, + Optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace, + regionResultTypes, inlineLoc, + shouldCloneInlinedRegion, + /*call=*/nullptr); +} + +LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, + Operation *inlinePoint, + ValueRange inlinedOperands, + ValueRange resultsToReplace, + Optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands, + resultsToReplace, inlineLoc, shouldCloneInlinedRegion, + /*call=*/nullptr); } /// Utility function used to generate a cast operation from the given interface, @@ -371,9 +399,9 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, return cleanupState(); // Attempt to inline the call. - if (failed(inlineRegion(interface, src, call, mapper, callResults, - callableResultTypes, call.getLoc(), - shouldCloneInlinedRegion))) + if (failed(inlineRegionImpl(interface, src, call, mapper, callResults, + callableResultTypes, call.getLoc(), + shouldCloneInlinedRegion, call))) return cleanupState(); return success(); } diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index d568be0429a9..e0368b25a2d2 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -140,9 +140,9 @@ func @convert_callee_fn_multiblock() -> i32 { // CHECK-LABEL: func @inline_convert_result_multiblock func @inline_convert_result_multiblock() -> i16 { -// CHECK: br ^bb1 +// CHECK: br ^bb1 {inlined_conversion} // CHECK: ^bb1: -// CHECK: %[[C:.+]] = constant 0 : i32 +// CHECK: %[[C:.+]] = constant {inlined_conversion} 0 : i32 // CHECK: br ^bb2(%[[C]] : i32) // CHECK: ^bb2(%[[BBARG:.+]]: i32): // CHECK: %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index a21e32a12eff..8ef6ec6000c6 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -171,6 +171,20 @@ struct TestInlinerInterface : public DialectInlinerInterface { return nullptr; return builder.create(conversionLoc, resultType, input); } + + void processInlinedCallBlocks( + Operation *call, + iterator_range inlinedBlocks) const final { + if (!isa(call)) + return; + + // Set attributed on all ops in the inlined blocks. + for (Block &block : inlinedBlocks) { + block.walk([&](Operation *op) { + op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); + }); + } + } }; struct TestReductionPatternInterface : public DialectReductionPatternInterface {