forked from OSchip/llvm-project
[mlir][linalg][bufferize] Add dialect filter to BufferizationOptions
This adds a new option `dialectFilter` to BufferizationOptions. Only ops from dialects that are allow-listed in the filter are bufferized. Other ops are left unbufferized. Note: This option requires `allowUnknownOps = true`. To make use of `dialectFilter`, BufferizationOptions or BufferizationState must be passed to various helper functions. The purpose of this change is to provide a better infrastructure for partial bufferization, which will be fully activated in a subsequent change. Differential Revision: https://reviews.llvm.org/D114691
This commit is contained in:
parent
84687405ce
commit
847710f7b7
|
@ -30,6 +30,7 @@ namespace comprehensive_bufferize {
|
|||
static constexpr int64_t kBufferAlignments = 128;
|
||||
|
||||
class BufferizationAliasInfo;
|
||||
class BufferizableOpInterface;
|
||||
struct BufferizationOptions;
|
||||
class BufferizationState;
|
||||
struct PostAnalysisStep;
|
||||
|
@ -92,6 +93,33 @@ struct BufferizationOptions {
|
|||
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/// Return `true` if the op is allowed to be bufferized.
|
||||
bool isOpAllowed(Operation *op) const {
|
||||
if (!dialectFilter.hasValue())
|
||||
return true;
|
||||
return dialectFilter->contains(op->getDialect()->getNamespace());
|
||||
}
|
||||
|
||||
/// Allow-list the given dialects in the dialect filter. Only ops from
|
||||
/// allow-listed dialects will be bufferized.
|
||||
template <typename... DialectTs>
|
||||
void addToDialectFilter() {
|
||||
// The following expands a call to addToDialectFilterImpl for each dialect
|
||||
// in 'DialectTs'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{
|
||||
0, (addToDialectFilterImpl<DialectTs>(), 0)...};
|
||||
}
|
||||
|
||||
/// Try to cast the given op to BufferizableOpInterface if the op is allow
|
||||
/// listed.
|
||||
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
|
||||
|
||||
/// Try to cast the given value to BufferizableOpInterface if the op is allow
|
||||
/// listed.
|
||||
BufferizableOpInterface dynCastBufferizableOp(Value value) const;
|
||||
|
||||
/// Helper functions for allocation, deallocation, memory copying.
|
||||
std::unique_ptr<AllocationCallbacks> allocationFns;
|
||||
|
||||
|
@ -114,6 +142,25 @@ struct BufferizationOptions {
|
|||
|
||||
/// Registered post analysis steps.
|
||||
PostAnalysisStepList postAnalysisSteps;
|
||||
|
||||
/// Only bufferize ops from dialects that are allowed-listed by the filter.
|
||||
/// All other ops are ignored. This option controls the scope of partial
|
||||
/// bufferization.
|
||||
///
|
||||
/// Note: If no filter is specified, all ops are bufferized (as long as they
|
||||
/// implement BufferizableOpInterface). If a filter is specified,
|
||||
/// `allowUnknownOps` should be enabled. Otherwise, bufferization would fail
|
||||
/// when encountering an op that is forbidden by the filter.
|
||||
Optional<DenseSet<StringRef>> dialectFilter;
|
||||
|
||||
private:
|
||||
/// Allow-list a dialect in the dialect filter.
|
||||
template <typename DialectT>
|
||||
void addToDialectFilterImpl() {
|
||||
if (!dialectFilter.hasValue())
|
||||
dialectFilter.emplace();
|
||||
dialectFilter->insert(DialectT::getDialectNamespace());
|
||||
}
|
||||
};
|
||||
|
||||
/// Specify fine-grain relationship between buffers to enable more analysis.
|
||||
|
@ -128,7 +175,8 @@ enum class BufferRelation {
|
|||
/// equivalence classes to support bufferization.
|
||||
class BufferizationAliasInfo {
|
||||
public:
|
||||
explicit BufferizationAliasInfo(Operation *rootOp);
|
||||
explicit BufferizationAliasInfo(Operation *rootOp,
|
||||
const BufferizationOptions &options);
|
||||
|
||||
// BufferizationAliasInfo should be passed as a reference.
|
||||
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
|
||||
|
@ -265,7 +313,7 @@ bool isValueRead(Value value);
|
|||
/// starting the traversal from Value 1, the resulting SetVector is:
|
||||
/// { 2, 7, 8, 5 }
|
||||
llvm::SetVector<Value>
|
||||
findValueInReverseUseDefChain(Value value,
|
||||
findValueInReverseUseDefChain(Value value, const BufferizationOptions &options,
|
||||
std::function<bool(Value)> condition);
|
||||
|
||||
/// Find the Value of the last preceding write of a given Value.
|
||||
|
@ -276,7 +324,7 @@ findValueInReverseUseDefChain(Value value,
|
|||
///
|
||||
/// Note: When reaching an end of the reverse SSA use-def chain, that value
|
||||
/// is returned regardless of whether it is a memory write or not.
|
||||
Value findLastPrecedingWrite(Value value);
|
||||
Value findLastPrecedingWrite(Value value, const BufferizationOptions &options);
|
||||
|
||||
/// Dialect-specific bufferization state. Analysis/bufferization information
|
||||
/// that is specific to ops from a certain dialect can be stored in derived
|
||||
|
@ -300,7 +348,7 @@ struct DialectBufferizationState {
|
|||
class BufferizationState {
|
||||
public:
|
||||
BufferizationState(Operation *op, const BufferizationOptions &options)
|
||||
: aliasInfo(op), options(options), builder(op->getContext()) {}
|
||||
: aliasInfo(op, options), options(options), builder(op->getContext()) {}
|
||||
|
||||
// BufferizationState should be passed as a reference.
|
||||
BufferizationState(const BufferizationState &) = delete;
|
||||
|
|
|
@ -266,6 +266,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*methodName=*/"isNotConflicting",
|
||||
/*args=*/(ins "OpOperand *":$uRead,
|
||||
"OpOperand *":$uWrite,
|
||||
"BufferizationState &":$state,
|
||||
"const BufferizationAliasInfo &":$aliasInfo),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
|
|
|
@ -78,7 +78,8 @@ BufferizationOptions::BufferizationOptions()
|
|||
// BufferizationAliasInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
||||
BufferizationAliasInfo::BufferizationAliasInfo(
|
||||
Operation *rootOp, const BufferizationOptions &options) {
|
||||
rootOp->walk([&](Operation *op) {
|
||||
for (Value v : op->getResults())
|
||||
if (v.getType().isa<TensorType>())
|
||||
|
@ -93,6 +94,8 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
|||
// Set up alias sets for OpResults that must bufferize in-place. This should
|
||||
// be done before making any other bufferization decisions.
|
||||
rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
|
||||
if (!options.isOpAllowed(bufferizableOp))
|
||||
return WalkResult::skip();
|
||||
for (OpResult opResult : bufferizableOp->getOpResults()) {
|
||||
if (opResult.getType().isa<TensorType>())
|
||||
if (bufferizableOp.mustBufferizeInPlace(opResult)) {
|
||||
|
@ -105,6 +108,7 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
|||
markInPlace(opResult);
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -197,6 +201,21 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
|
|||
}
|
||||
}
|
||||
|
||||
BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
|
||||
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
|
||||
if (isOpAllowed(op))
|
||||
return dyn_cast<BufferizableOpInterface>(op);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
|
||||
BufferizationOptions::dynCastBufferizableOp(Value value) const {
|
||||
if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
|
||||
if (isOpAllowed(bufferizableOp.getOperation()))
|
||||
return bufferizableOp;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Determine which OpOperand* will alias with `result` if the op is bufferized
|
||||
/// in place. Return an empty vector if the op is not bufferizable.
|
||||
SmallVector<OpOperand *>
|
||||
|
@ -283,7 +302,8 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
|
|||
// further.
|
||||
llvm::SetVector<Value>
|
||||
mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
|
||||
Value value, std::function<bool(Value)> condition) {
|
||||
Value value, const BufferizationOptions &options,
|
||||
std::function<bool(Value)> condition) {
|
||||
llvm::SetVector<Value> result, workingSet;
|
||||
workingSet.insert(value);
|
||||
|
||||
|
@ -296,7 +316,7 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
|
|||
|
||||
OpResult opResult = value.cast<OpResult>();
|
||||
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
|
||||
if (opOperands.empty()) {
|
||||
if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
|
||||
result.insert(value);
|
||||
continue;
|
||||
}
|
||||
|
@ -310,13 +330,13 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
|
|||
|
||||
// Find the Value of the last preceding write of a given Value.
|
||||
Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
|
||||
Value value) {
|
||||
Value value, const BufferizationOptions &options) {
|
||||
SetVector<Value> result =
|
||||
findValueInReverseUseDefChain(value, [](Value value) {
|
||||
findValueInReverseUseDefChain(value, options, [&](Value value) {
|
||||
Operation *op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return true;
|
||||
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
|
||||
auto bufferizableOp = options.dynCastBufferizableOp(op);
|
||||
if (!bufferizableOp)
|
||||
return true;
|
||||
return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
|
||||
|
@ -374,9 +394,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
|
||||
// use-def chain, it returns that value, regardless of whether it is a
|
||||
// memory write or not.
|
||||
Value lastWrite = findLastPrecedingWrite(operand);
|
||||
if (auto bufferizableOp =
|
||||
lastWrite.getDefiningOp<BufferizableOpInterface>())
|
||||
Value lastWrite = findLastPrecedingWrite(operand, options);
|
||||
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
|
||||
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
|
||||
skipCopy = true;
|
||||
// Do not copy if the copied data is never read.
|
||||
|
@ -433,7 +452,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
|
|||
|
||||
// Bufferize using `BufferizableOpInterface`. Interface implementations are
|
||||
// responsible for bufferizing nested ops.
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
|
||||
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
|
||||
b.setInsertionPoint(op);
|
||||
return bufferizableOp.bufferize(b, state);
|
||||
}
|
||||
|
@ -640,8 +659,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
|
|||
if (options.allowUnknownOps) {
|
||||
// `tensor` was not bufferized yet. This should never happen with
|
||||
// bufferizable ops.
|
||||
assert(!tensor.getDefiningOp<BufferizableOpInterface>() &&
|
||||
"tensor is not mapped");
|
||||
assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped");
|
||||
// Insert to_memref op.
|
||||
OpBuilder b(tensor.getContext());
|
||||
setInsertionPointAfter(b, tensor);
|
||||
|
|
|
@ -256,13 +256,13 @@ static bool aliasesNonWritableBuffer(Value value,
|
|||
aliasInfo.applyOnAliases(value, [&](Value v) {
|
||||
// Query BufferizableOpInterface to see if the OpResult is writable.
|
||||
// TODO: Out-of-place bufferized OpResult could be considered writable.
|
||||
if (auto bufferizableOp = v.getDefiningOp<BufferizableOpInterface>())
|
||||
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
|
||||
if (bufferizableOp && bufferizableOp.isWritable(v, state))
|
||||
return;
|
||||
|
||||
// Query BufferizableOpInterface to see if the BlockArgument is writable.
|
||||
if (auto bbArg = v.dyn_cast<BlockArgument>())
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(
|
||||
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
|
||||
bbArg.getOwner()->getParentOp()))
|
||||
if (bufferizableOp.isWritable(bbArg, state))
|
||||
return;
|
||||
|
@ -324,11 +324,12 @@ static bool happensBefore(Operation *a, Operation *b,
|
|||
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
|
||||
/// the result of a write W1. But because of bufferization decisions, R actually
|
||||
/// reads another write W2.
|
||||
static bool
|
||||
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
||||
const DenseSet<OpOperand *> &usesWrite,
|
||||
const DominanceInfo &domInfo,
|
||||
const BufferizationAliasInfo &aliasInfo) {
|
||||
static bool hasReadAfterWriteInterference(
|
||||
const DenseSet<OpOperand *> &usesRead,
|
||||
const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
|
||||
BufferizationState &state, const BufferizationAliasInfo &aliasInfo) {
|
||||
const BufferizationOptions &options = state.getOptions();
|
||||
|
||||
for (OpOperand *uRead : usesRead) {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
|
||||
|
@ -341,7 +342,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
|||
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
|
||||
// is %0. Note that operations that create an alias but do not write (such
|
||||
// as ExtractSliceOp) are skipped.
|
||||
Value lastWrite = findLastPrecedingWrite(uRead->get());
|
||||
Value lastWrite = findLastPrecedingWrite(uRead->get(), options);
|
||||
|
||||
// Look for conflicting memory writes. Potential conflicts are writes to an
|
||||
// alias that have been decided to bufferize inplace.
|
||||
|
@ -370,15 +371,15 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
|||
continue;
|
||||
|
||||
// No conflict if the op interface says so.
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
|
||||
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
|
||||
aliasInfo))
|
||||
continue;
|
||||
|
||||
if (conflictingWritingOp != readingOp)
|
||||
if (auto bufferizableOp =
|
||||
dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
|
||||
options.dynCastBufferizableOp(conflictingWritingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
|
||||
aliasInfo))
|
||||
continue;
|
||||
|
||||
|
@ -452,7 +453,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
|||
/// involving aliases of the given OpOperand are checked.
|
||||
bool wouldCreateReadAfterWriteInterference(
|
||||
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state, const BufferizationAliasInfo &aliasInfo,
|
||||
bool checkConsistencyOnly = false) {
|
||||
#ifndef NDEBUG
|
||||
if (result) {
|
||||
|
@ -496,7 +497,8 @@ bool wouldCreateReadAfterWriteInterference(
|
|||
if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
|
||||
usesWrite.insert(&operand);
|
||||
|
||||
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
|
||||
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
|
||||
aliasInfo);
|
||||
}
|
||||
|
||||
/// Return true if bufferizing `opOperand` inplace with `opResult` would create
|
||||
|
@ -555,7 +557,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
|
|||
|
||||
bool foundInterference =
|
||||
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) ||
|
||||
wouldCreateReadAfterWriteInterference(operand, result, domInfo,
|
||||
wouldCreateReadAfterWriteInterference(operand, result, domInfo, state,
|
||||
aliasInfo);
|
||||
|
||||
if (foundInterference)
|
||||
|
@ -603,7 +605,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
|
|||
for (Operation *op : reverse(ops))
|
||||
for (OpOperand &opOperand : op->getOpOperands())
|
||||
if (opOperand.get().getType().isa<TensorType>())
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
|
||||
if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
|
||||
if (failed(bufferizableInPlaceAnalysisImpl(
|
||||
opOperand, opResult, aliasInfo, state, domInfo)))
|
||||
|
@ -633,9 +635,10 @@ static LogicalResult inPlaceAnalysis(Operation *op,
|
|||
|
||||
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
|
||||
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
|
||||
BufferizationAliasInfo &aliasInfo) {
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationOptions &options) {
|
||||
for (Operation *op : ops)
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
|
||||
for (OpResult opResult : op->getOpResults())
|
||||
if (opResult.getType().isa<TensorType>())
|
||||
if (aliasInfo.isInPlace(opResult)) {
|
||||
|
@ -652,7 +655,8 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
|
|||
/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
|
||||
/// in `op`.
|
||||
static void equivalenceAnalysis(Operation *op,
|
||||
BufferizationAliasInfo &aliasInfo) {
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationOptions &options) {
|
||||
// Traverse ops in PostOrder: Nested ops first, then enclosing ops.
|
||||
SmallVector<Operation *> ops;
|
||||
op->walk<WalkOrder::PostOrder>([&](Operation *op) {
|
||||
|
@ -662,21 +666,23 @@ static void equivalenceAnalysis(Operation *op,
|
|||
ops.push_back(op);
|
||||
});
|
||||
|
||||
equivalenceAnalysis(ops, aliasInfo);
|
||||
equivalenceAnalysis(ops, aliasInfo, options);
|
||||
}
|
||||
|
||||
/// Assert that the current bufferization decisions are consistent.
|
||||
static LogicalResult
|
||||
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
|
||||
BufferizationState &state,
|
||||
const BufferizationAliasInfo &aliasInfo) {
|
||||
const BufferizationOptions &options = state.getOptions();
|
||||
Operation *inconsistentOp = nullptr;
|
||||
WalkResult walkResult = op->walk([&](Operation *op) {
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
|
||||
for (OpOperand &opOperand : op->getOpOperands())
|
||||
if (opOperand.get().getType().isa<TensorType>()) {
|
||||
OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
|
||||
if (wouldCreateReadAfterWriteInterference(
|
||||
opOperand, opResult, domInfo, aliasInfo,
|
||||
opOperand, opResult, domInfo, state, aliasInfo,
|
||||
/*checkConsistencyOnly=*/true)) {
|
||||
// This error can happen for two reasons. Either the input IR
|
||||
// already has a read-after-write conflict. Or certain
|
||||
|
@ -723,14 +729,14 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
DominanceInfo domInfo(op);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
|
||||
if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo)))
|
||||
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
|
||||
return failure();
|
||||
|
||||
// If the analysis fails, just return.
|
||||
if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
|
||||
options.analysisFuzzerSeed)))
|
||||
return failure();
|
||||
equivalenceAnalysis(op, aliasInfo);
|
||||
equivalenceAnalysis(op, aliasInfo, options);
|
||||
|
||||
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
|
||||
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
|
||||
|
@ -740,7 +746,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
// Analyze ops that were created by the PostAnalysisStep.
|
||||
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
|
||||
return failure();
|
||||
equivalenceAnalysis(newOps, aliasInfo);
|
||||
equivalenceAnalysis(newOps, aliasInfo, options);
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
|
|
@ -388,6 +388,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
|
|||
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
OpBuilder b(op->getContext());
|
||||
const BufferizationOptions &options = state.getOptions();
|
||||
|
||||
WalkResult status = op->walk([&](Operation *op) {
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
|
@ -396,7 +397,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
|
|||
continue;
|
||||
|
||||
SetVector<Value> maybeInitTensor =
|
||||
findValueInReverseUseDefChain(operand.get(), [&](Value val) {
|
||||
findValueInReverseUseDefChain(operand.get(), options, [&](Value val) {
|
||||
// Continue traversal until this function returns true.
|
||||
OpResult opResult = val.dyn_cast<OpResult>();
|
||||
if (!opResult)
|
||||
|
|
|
@ -276,6 +276,7 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
|
|||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationOptions &options,
|
||||
Value value, InsertSliceOp insertOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
|
@ -284,7 +285,7 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
|||
return false;
|
||||
};
|
||||
|
||||
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
|
||||
return llvm::all_of(findValueInReverseUseDefChain(value, options, condition),
|
||||
condition);
|
||||
}
|
||||
|
||||
|
@ -311,7 +312,7 @@ struct InsertSliceOpInterface
|
|||
}
|
||||
|
||||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite,
|
||||
OpOperand *uConflictingWrite, BufferizationState &state,
|
||||
const BufferizationAliasInfo &aliasInfo) const {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
@ -328,8 +329,8 @@ struct InsertSliceOpInterface
|
|||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
|
||||
uConflictingWrite->get(), insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
// uConflictingWrite writes into exactly the memory location that is
|
||||
|
@ -346,7 +347,8 @@ struct InsertSliceOpInterface
|
|||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
|
||||
hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(),
|
||||
insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
|
@ -379,8 +381,8 @@ struct InsertSliceOpInterface
|
|||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.source()) &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
|
||||
insertSliceOp))
|
||||
hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
|
||||
insertSliceOp.source(), insertSliceOp))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
|
||||
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
|
||||
|
||||
// CHECK-LABEL: func @use_of_unknown_op_1(
|
||||
// CHECK-SAME: %[[m1:.*]]: memref<?xf32
|
||||
func @use_of_unknown_op_1(%t1: tensor<?xf32> {linalg.inplaceable = true})
|
||||
|
@ -148,3 +150,20 @@ func @unknown_op_not_writable(
|
|||
// CHECK: return %[[alloc]]
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-TENSOR-LABEL: func @simple_tensor_test(
|
||||
// CHECK-TENSOR-SAME: %[[t1:.*]]: tensor<?xf32>
|
||||
func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
|
||||
// CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
|
||||
// CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
|
||||
// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
|
||||
// CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
|
||||
%0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
|
||||
// CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
|
||||
// CHECK-TENSOR: return %[[casted_tensor]]
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
|
|
@ -85,6 +85,10 @@ struct TestComprehensiveFunctionBufferize
|
|||
*this, "analysis-fuzzer-seed",
|
||||
llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
|
||||
llvm::cl::init(0)};
|
||||
ListOption<std::string> dialectFilter{
|
||||
*this, "dialect-filter",
|
||||
llvm::cl::desc("Bufferize only ops from the specified dialects"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -104,6 +108,12 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
|
|||
options.testAnalysisOnly = testAnalysisOnly;
|
||||
options.analysisFuzzerSeed = analysisFuzzerSeed;
|
||||
|
||||
if (dialectFilter.hasValue()) {
|
||||
options.dialectFilter.emplace();
|
||||
for (const std::string &dialectNamespace : dialectFilter)
|
||||
options.dialectFilter->insert(dialectNamespace);
|
||||
}
|
||||
|
||||
Operation *op = getFunction().getOperation();
|
||||
if (failed(runComprehensiveBufferize(op, options)))
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue