forked from OSchip/llvm-project
[mlir][bufferize] Add isRepetitiveRegion to BufferizableOpInterface
This method allows regions to be declared as "repetitive" even if the parent op does not implement the RegionBranchOpInterface. This is needed to support loop-like ops that have parallel semantics but do not branch between regions. Differential Revision: https://reviews.llvm.org/D133113
This commit is contained in:
parent
3da23970ed
commit
f7f0c7f7e3
|
@ -566,6 +566,12 @@ namespace detail {
|
|||
FailureOr<BaseMemRefType>
|
||||
defaultGetBufferType(Value value, const BufferizationOptions &options,
|
||||
const DenseMap<Value, BaseMemRefType> &fixedTypes);
|
||||
|
||||
/// This is the default implementation of
|
||||
/// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
|
||||
/// places.
|
||||
bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
|
||||
unsigned index);
|
||||
} // namespace detail
|
||||
|
||||
} // namespace bufferization
|
||||
|
|
|
@ -360,6 +360,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
value, options, fixedTypes);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return `true` if the given region of this op is repetitive. By default
|
||||
this information is queried from the `RegionBranchOpInterface`. Ops
|
||||
that do not implement this interface can override this method to
|
||||
declare regions as repetitive.
|
||||
|
||||
The RaW conflict detection of One-Shot Analysis is more strict inside
|
||||
repetitive regions: Op dominance cannot always be used to rule out
|
||||
certain potential conflicts (e.g., a conflicting write happening after
|
||||
a read), because there may not be a meaningful ordering of certain ops
|
||||
that are executed multiple times. This is described in more detail in
|
||||
the documentation of One-Shot Analysis.
|
||||
}],
|
||||
/*retType=*/"bool",
|
||||
/*methodName=*/"isRepetitiveRegion",
|
||||
/*args=*/(ins "unsigned":$index),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return mlir::bufferization::detail::defaultIsRepetitiveRegion(
|
||||
cast<BufferizableOpInterface>($_op.getOperation()), index);
|
||||
}]
|
||||
>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -784,3 +785,13 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
|
|||
rankedTensorType.getElementType(), layout,
|
||||
memorySpaceAttr);
|
||||
}
|
||||
|
||||
/// Default implementation of BufferizableOpInterface::isRepetitiveRegion.
///
/// Reports whether the region with the given `index` of `bufferizableOp` is
/// repetitive by forwarding the query to the op's RegionBranchOpInterface.
/// If the op does not implement RegionBranchOpInterface, the region is
/// reported as not repetitive (returns `false`).
///
/// `index` must be smaller than the number of regions of the op; this is
/// checked with an assertion.
bool bufferization::detail::defaultIsRepetitiveRegion(
|
||||
BufferizableOpInterface bufferizableOp, unsigned index) {
|
||||
assert(index < bufferizableOp->getNumRegions() && "invalid region index");
|
||||
// Only RegionBranchOpInterface can answer this query in the default case.
auto regionInterface =
|
||||
dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
|
||||
// Without that interface there is nothing to ask: treat the region as
// non-repetitive.
if (!regionInterface)
|
||||
return false;
|
||||
return regionInterface.isRepetitiveRegion(index);
|
||||
}
|
||||
|
|
|
@ -351,14 +351,40 @@ static bool happensBefore(Operation *a, Operation *b,
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Return the closest region enclosing `op` that is declared repetitive by
/// its parent op, or nullptr if no such region is found.
///
/// The search starts at the parent region of `op` (regions owned by `op`
/// itself are not considered) and walks outwards one parent op at a time. A
/// region qualifies only if `options.dynCastBufferizableOp` yields a
/// BufferizableOpInterface for its parent op and that interface's
/// `isRepetitiveRegion` returns true for the region's number.
static Region *
|
||||
getEnclosingRepetitiveRegion(Operation *op,
|
||||
const BufferizationOptions &options) {
|
||||
// The loop ends once the current op has no parent region.
while (Region *region = op->getParentRegion()) {
|
||||
// Step up to the op that owns `region`; it is both the op that is queried
// below and the starting point of the next iteration.
op = region->getParentOp();
|
||||
// Parent ops for which dynCastBufferizableOp returns null are skipped.
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
|
||||
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
|
||||
return region;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Return the closest region enclosing `value` that is declared repetitive by
/// its parent op, or nullptr if no such region is found.
///
/// The search starts at `value.getParentRegion()` and walks outwards. A
/// region qualifies only if `options.dynCastBufferizableOp` yields a
/// BufferizableOpInterface for its parent op and that interface's
/// `isRepetitiveRegion` returns true for the region's number.
static Region *
|
||||
getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
|
||||
Region *region = value.getParentRegion();
|
||||
// Unlike the Operation overload, the first candidate is the value's own
// parent region; iteration stops when there is no further enclosing region.
while (region) {
|
||||
Operation *op = region->getParentOp();
|
||||
// Parent ops for which dynCastBufferizableOp returns null are skipped.
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
|
||||
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
|
||||
return region;
|
||||
// Not repetitive: continue with the region that contains the parent op.
region = op->getParentRegion();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// For each given value, find the closest enclosing repetitive region. If this
|
||||
/// is the same region for each value, return it. Otherwise return None.
|
||||
/// Note: If there is no enclosing repetitive region, return nullptr.
|
||||
static Optional<Region *>
|
||||
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
|
||||
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values,
|
||||
const BufferizationOptions &options) {
|
||||
if (values.empty())
|
||||
return None;
|
||||
Region *r = getEnclosingRepetitiveRegion(values.front());
|
||||
Region *r = getEnclosingRepetitiveRegion(values.front(), options);
|
||||
for (Value value : values.drop_front())
|
||||
if (getEnclosingRepetitiveRegion(value) != r)
|
||||
return None;
|
||||
|
@ -432,7 +458,7 @@ static bool hasReadAfterWriteInterference(
|
|||
// Find the inner-most enclosing repetitive region of each alias. If this is
|
||||
// the same region for every alias, save it in `repetitiveRegionOfWrites`.
|
||||
Optional<Region *> repetitiveRegionOfWrites =
|
||||
getCommonEnclosingRepetitiveRegion(writtenAliases);
|
||||
getCommonEnclosingRepetitiveRegion(writtenAliases, options);
|
||||
|
||||
for (OpOperand *uRead : usesRead) {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
|
@ -497,7 +523,7 @@ static bool hasReadAfterWriteInterference(
|
|||
bool canUseOpDominance =
|
||||
writtenAliases.empty() ||
|
||||
repetitiveRegionOfWrites ==
|
||||
getEnclosingRepetitiveRegion(conflictingWritingOp);
|
||||
getEnclosingRepetitiveRegion(conflictingWritingOp, options);
|
||||
|
||||
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
|
||||
// write is not visible when reading.
|
||||
|
|
|
@ -48,15 +48,14 @@ resolveUsesInRepetitiveRegions(Operation *op,
|
|||
AnalysisState state(options);
|
||||
|
||||
// Look for repetitive ops (loops).
|
||||
op->walk([&](RegionBranchOpInterface regionBranchOp) {
|
||||
// Skip non-bufferizable ops.
|
||||
auto bufferizableOp = options.dynCastBufferizableOp(regionBranchOp);
|
||||
if (!bufferizableOp)
|
||||
op->walk([&](BufferizableOpInterface bufferizableOp) {
|
||||
// Skip filtered ops.
|
||||
if (!options.isOpAllowed(bufferizableOp.getOperation()))
|
||||
return WalkResult::advance();
|
||||
|
||||
// Find all operands that are also used inside of a repetitve region of this
|
||||
// op.
|
||||
for (OpOperand &opOperand : regionBranchOp->getOpOperands()) {
|
||||
// Find all operands that are also used inside of a repetitive region of
|
||||
// this op.
|
||||
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
||||
Value operand = opOperand.get();
|
||||
// Skip non-tensor operands.
|
||||
if (!operand.getType().isa<TensorType>())
|
||||
|
@ -69,11 +68,11 @@ resolveUsesInRepetitiveRegions(Operation *op,
|
|||
SmallVector<OpOperand *> usesInsideRegion;
|
||||
for (OpOperand &use : operand.getUses()) {
|
||||
Operation *owner = use.getOwner();
|
||||
if (!regionBranchOp->isProperAncestor(owner))
|
||||
if (!bufferizableOp->isProperAncestor(owner))
|
||||
continue;
|
||||
for (Region &r : regionBranchOp->getRegions()) {
|
||||
for (Region &r : bufferizableOp->getRegions()) {
|
||||
if (r.findAncestorOpInRegion(*owner) &&
|
||||
regionBranchOp.isRepetitiveRegion(r.getRegionNumber())) {
|
||||
bufferizableOp.isRepetitiveRegion(r.getRegionNumber())) {
|
||||
usesInsideRegion.push_back(&use);
|
||||
break;
|
||||
}
|
||||
|
@ -84,9 +83,9 @@ resolveUsesInRepetitiveRegions(Operation *op,
|
|||
continue;
|
||||
|
||||
// Insert a tensor copy and replace all uses inside of repetitive regions.
|
||||
rewriter.setInsertionPoint(regionBranchOp);
|
||||
rewriter.setInsertionPoint(bufferizableOp);
|
||||
auto tensorCopy = rewriter.create<AllocTensorOp>(
|
||||
regionBranchOp->getLoc(), operand.getType().cast<TensorType>(),
|
||||
bufferizableOp->getLoc(), operand.getType().cast<TensorType>(),
|
||||
/*dynamicSizes=*/ValueRange(),
|
||||
/*copy=*/operand, /*memory_space=*/IntegerAttr());
|
||||
for (OpOperand *use : usesInsideRegion)
|
||||
|
|
|
@ -9088,6 +9088,7 @@ cc_library(
|
|||
":BufferizableOpInterfaceIncGen",
|
||||
":BufferizationBaseIncGen",
|
||||
":BufferizationOpsIncGen",
|
||||
":ControlFlowInterfaces",
|
||||
":CopyOpInterface",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
|
|
Loading…
Reference in New Issue