[mlir][bufferize] Add isRepetitiveRegion to BufferizableOpInterface

This method allows to declare regions as "repetitive" even if the parent op does not implement the RegionBranchOpInterface.

This is needed to support loop-like ops that have parallel semantics but do not branch between regions.

Differential Revision: https://reviews.llvm.org/D133113
This commit is contained in:
Matthias Springer 2022-09-02 14:32:04 +02:00
parent 3da23970ed
commit f7f0c7f7e3
6 changed files with 82 additions and 16 deletions

View File

@ -566,6 +566,12 @@ namespace detail {
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const DenseMap<Value, BaseMemRefType> &fixedTypes);
/// This is the default implementation of
/// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
/// places.
bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
unsigned index);
} // namespace detail
} // namespace bufferization

View File

@ -360,6 +360,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
value, options, fixedTypes);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return `true` if the given region of this op is repetitive. By default
this information is queried from the `RegionBranchOpInterface`. Ops
that do not implement this inferface can override this method to
declare regions as repetitive.
The RaW conflict detection of One-Shot Analysis is more strict inside
repetitive regions: Op dominance cannot always be used to rule out
certain potential conflicts (e.g., a conflicting write happening after
a read), because there may not be a meaningful ordering of certain ops
that are executed multiple times. This is described in more detail in
documentation of One-Shot Analysis.
}],
/*retType=*/"bool",
/*methodName=*/"isRepetitiveRegion",
/*args=*/(ins "unsigned":$index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return mlir::bufferization::detail::defaultIsRepetitiveRegion(
cast<BufferizableOpInterface>($_op.getOperation()), index);
}]
>
];
let extraClassDeclaration = [{

View File

@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
@ -784,3 +785,13 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
rankedTensorType.getElementType(), layout,
memorySpaceAttr);
}
bool bufferization::detail::defaultIsRepetitiveRegion(
BufferizableOpInterface bufferizableOp, unsigned index) {
assert(index < bufferizableOp->getNumRegions() && "invalid region index");
auto regionInterface =
dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
if (!regionInterface)
return false;
return regionInterface.isRepetitiveRegion(index);
}

View File

@ -351,14 +351,40 @@ static bool happensBefore(Operation *a, Operation *b,
return false;
}
static Region *
getEnclosingRepetitiveRegion(Operation *op,
const BufferizationOptions &options) {
while (Region *region = op->getParentRegion()) {
op = region->getParentOp();
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
return region;
}
return nullptr;
}
static Region *
getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
Region *region = value.getParentRegion();
while (region) {
Operation *op = region->getParentOp();
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
return region;
region = op->getParentRegion();
}
return nullptr;
}
/// For each given value, find the closest enclosing repetitive region. If this
/// is the same region for each value, return it. Otherwise return None.
/// Note: If there is no enclosing repetitive region, return nullptr.
static Optional<Region *>
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values,
const BufferizationOptions &options) {
if (values.empty())
return None;
Region *r = getEnclosingRepetitiveRegion(values.front());
Region *r = getEnclosingRepetitiveRegion(values.front(), options);
for (Value value : values.drop_front())
if (getEnclosingRepetitiveRegion(value) != r)
return None;
@ -432,7 +458,7 @@ static bool hasReadAfterWriteInterference(
// Find the inner-most enclosing repetitive region of each alias. If this is
// the same region for every alias, save it in `repetitiveRegionOfWrites`.
Optional<Region *> repetitiveRegionOfWrites =
getCommonEnclosingRepetitiveRegion(writtenAliases);
getCommonEnclosingRepetitiveRegion(writtenAliases, options);
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@ -497,7 +523,7 @@ static bool hasReadAfterWriteInterference(
bool canUseOpDominance =
writtenAliases.empty() ||
repetitiveRegionOfWrites ==
getEnclosingRepetitiveRegion(conflictingWritingOp);
getEnclosingRepetitiveRegion(conflictingWritingOp, options);
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
// write is not visible when reading.

View File

@ -48,15 +48,14 @@ resolveUsesInRepetitiveRegions(Operation *op,
AnalysisState state(options);
// Look for repetitive ops (loops).
op->walk([&](RegionBranchOpInterface regionBranchOp) {
// Skip non-bufferizable ops.
auto bufferizableOp = options.dynCastBufferizableOp(regionBranchOp);
if (!bufferizableOp)
op->walk([&](BufferizableOpInterface bufferizableOp) {
// Skip filtered ops.
if (!options.isOpAllowed(bufferizableOp.getOperation()))
return WalkResult::advance();
// Find all operands that are also used inside of a repetitve region of this
// op.
for (OpOperand &opOperand : regionBranchOp->getOpOperands()) {
// Find all operands that are also used inside of a repetitive region of
// this op.
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
Value operand = opOperand.get();
// Skip non-tensor operands.
if (!operand.getType().isa<TensorType>())
@ -69,11 +68,11 @@ resolveUsesInRepetitiveRegions(Operation *op,
SmallVector<OpOperand *> usesInsideRegion;
for (OpOperand &use : operand.getUses()) {
Operation *owner = use.getOwner();
if (!regionBranchOp->isProperAncestor(owner))
if (!bufferizableOp->isProperAncestor(owner))
continue;
for (Region &r : regionBranchOp->getRegions()) {
for (Region &r : bufferizableOp->getRegions()) {
if (r.findAncestorOpInRegion(*owner) &&
regionBranchOp.isRepetitiveRegion(r.getRegionNumber())) {
bufferizableOp.isRepetitiveRegion(r.getRegionNumber())) {
usesInsideRegion.push_back(&use);
break;
}
@ -84,9 +83,9 @@ resolveUsesInRepetitiveRegions(Operation *op,
continue;
// Insert a tensor copy and replace all uses inside of repetitive regions.
rewriter.setInsertionPoint(regionBranchOp);
rewriter.setInsertionPoint(bufferizableOp);
auto tensorCopy = rewriter.create<AllocTensorOp>(
regionBranchOp->getLoc(), operand.getType().cast<TensorType>(),
bufferizableOp->getLoc(), operand.getType().cast<TensorType>(),
/*dynamicSizes=*/ValueRange(),
/*copy=*/operand, /*memory_space=*/IntegerAttr());
for (OpOperand *use : usesInsideRegion)

View File

@ -9088,6 +9088,7 @@ cc_library(
":BufferizableOpInterfaceIncGen",
":BufferizationBaseIncGen",
":BufferizationOpsIncGen",
":ControlFlowInterfaces",
":CopyOpInterface",
":FuncDialect",
":IR",