[mlir][linalg][bufferize] Add dialect filter to BufferizationOptions

This adds a new option `dialectFilter` to BufferizationOptions. Only ops from dialects that are allow-listed in the filter are bufferized. Other ops are left unbufferized. Note: This option requires `allowUnknownOps = true`.

To make use of `dialectFilter`, BufferizationOptions or BufferizationState must be passed to various helper functions.

The purpose of this change is to provide a better infrastructure for partial bufferization, which will be fully activated in a subsequent change.

Differential Revision: https://reviews.llvm.org/D114691
This commit is contained in:
Matthias Springer 2021-12-08 23:26:22 +09:00
parent 84687405ce
commit 847710f7b7
8 changed files with 154 additions and 49 deletions

View File

@ -30,6 +30,7 @@ namespace comprehensive_bufferize {
static constexpr int64_t kBufferAlignments = 128;
class BufferizationAliasInfo;
class BufferizableOpInterface;
struct BufferizationOptions;
class BufferizationState;
struct PostAnalysisStep;
@ -92,6 +93,33 @@ struct BufferizationOptions {
std::make_unique<Step>(std::forward<Args>(args)...));
}
/// Return `true` if the op is allowed to be bufferized.
bool isOpAllowed(Operation *op) const {
if (!dialectFilter.hasValue())
return true;
return dialectFilter->contains(op->getDialect()->getNamespace());
}
/// Allow-list the given dialects in the dialect filter. Only ops from
/// allow-listed dialects will be bufferized.
template <typename... DialectTs>
void addToDialectFilter() {
// The following expands a call to addToDialectFilterImpl for each dialect
// in 'DialectTs'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
0, (addToDialectFilterImpl<DialectTs>(), 0)...};
}
/// Try to cast the given op to BufferizableOpInterface if the op is allow
/// listed.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
/// Try to cast the given value to BufferizableOpInterface if the op is allow
/// listed.
BufferizableOpInterface dynCastBufferizableOp(Value value) const;
/// Helper functions for allocation, deallocation, memory copying.
std::unique_ptr<AllocationCallbacks> allocationFns;
@ -114,6 +142,25 @@ struct BufferizationOptions {
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
/// Only bufferize ops from dialects that are allowed-listed by the filter.
/// All other ops are ignored. This option controls the scope of partial
/// bufferization.
///
/// Note: If no filter is specified, all ops are bufferized (as long as they
/// implement BufferizableOpInterface). If a filter is specified,
/// `allowUnknownOps` should be enabled. Otherwise, bufferization would fail
/// when encountering an op that is forbidden by the filter.
Optional<DenseSet<StringRef>> dialectFilter;
private:
/// Allow-list a dialect in the dialect filter.
template <typename DialectT>
void addToDialectFilterImpl() {
if (!dialectFilter.hasValue())
dialectFilter.emplace();
dialectFilter->insert(DialectT::getDialectNamespace());
}
};
/// Specify fine-grain relationship between buffers to enable more analysis.
@ -128,7 +175,8 @@ enum class BufferRelation {
/// equivalence classes to support bufferization.
class BufferizationAliasInfo {
public:
explicit BufferizationAliasInfo(Operation *rootOp);
explicit BufferizationAliasInfo(Operation *rootOp,
const BufferizationOptions &options);
// BufferizationAliasInfo should be passed as a reference.
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
@ -265,7 +313,7 @@ bool isValueRead(Value value);
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
llvm::SetVector<Value>
findValueInReverseUseDefChain(Value value,
findValueInReverseUseDefChain(Value value, const BufferizationOptions &options,
std::function<bool(Value)> condition);
/// Find the Value of the last preceding write of a given Value.
@ -276,7 +324,7 @@ findValueInReverseUseDefChain(Value value,
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
/// is returned regardless of whether it is a memory write or not.
Value findLastPrecedingWrite(Value value);
Value findLastPrecedingWrite(Value value, const BufferizationOptions &options);
/// Dialect-specific bufferization state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
@ -300,7 +348,7 @@ struct DialectBufferizationState {
class BufferizationState {
public:
BufferizationState(Operation *op, const BufferizationOptions &options)
: aliasInfo(op), options(options), builder(op->getContext()) {}
: aliasInfo(op, options), options(options), builder(op->getContext()) {}
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;

View File

@ -266,6 +266,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"isNotConflicting",
/*args=*/(ins "OpOperand *":$uRead,
"OpOperand *":$uWrite,
"BufferizationState &":$state,
"const BufferizationAliasInfo &":$aliasInfo),
/*methodBody=*/"",
/*defaultImplementation=*/[{

View File

@ -78,7 +78,8 @@ BufferizationOptions::BufferizationOptions()
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
BufferizationAliasInfo::BufferizationAliasInfo(
Operation *rootOp, const BufferizationOptions &options) {
rootOp->walk([&](Operation *op) {
for (Value v : op->getResults())
if (v.getType().isa<TensorType>())
@ -93,6 +94,8 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
if (!options.isOpAllowed(bufferizableOp))
return WalkResult::skip();
for (OpResult opResult : bufferizableOp->getOpResults()) {
if (opResult.getType().isa<TensorType>())
if (bufferizableOp.mustBufferizeInPlace(opResult)) {
@ -105,6 +108,7 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
markInPlace(opResult);
}
}
return WalkResult::advance();
});
}
@ -197,6 +201,21 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
}
}
BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
if (isOpAllowed(op))
return dyn_cast<BufferizableOpInterface>(op);
return nullptr;
}
BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Value value) const {
if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
if (isOpAllowed(bufferizableOp.getOperation()))
return bufferizableOp;
return nullptr;
}
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
@ -283,7 +302,8 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
// further.
llvm::SetVector<Value>
mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
Value value, std::function<bool(Value)> condition) {
Value value, const BufferizationOptions &options,
std::function<bool(Value)> condition) {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@ -296,7 +316,7 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
OpResult opResult = value.cast<OpResult>();
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
if (opOperands.empty()) {
if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
result.insert(value);
continue;
}
@ -310,13 +330,13 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
// Find the Value of the last preceding write of a given Value.
Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
Value value) {
Value value, const BufferizationOptions &options) {
SetVector<Value> result =
findValueInReverseUseDefChain(value, [](Value value) {
findValueInReverseUseDefChain(value, options, [&](Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return true;
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (!bufferizableOp)
return true;
return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
@ -374,9 +394,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
Value lastWrite = findLastPrecedingWrite(operand);
if (auto bufferizableOp =
lastWrite.getDefiningOp<BufferizableOpInterface>())
Value lastWrite = findLastPrecedingWrite(operand, options);
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
skipCopy = true;
// Do not copy if the copied data is never read.
@ -433,7 +452,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
// Bufferize using `BufferizableOpInterface`. Interface implementations are
// responsible for bufferizing nested ops.
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
b.setInsertionPoint(op);
return bufferizableOp.bufferize(b, state);
}
@ -640,8 +659,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
if (options.allowUnknownOps) {
// `tensor` was not bufferized yet. This should never happen with
// bufferizable ops.
assert(!tensor.getDefiningOp<BufferizableOpInterface>() &&
"tensor is not mapped");
assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped");
// Insert to_memref op.
OpBuilder b(tensor.getContext());
setInsertionPointAfter(b, tensor);

View File

@ -256,13 +256,13 @@ static bool aliasesNonWritableBuffer(Value value,
aliasInfo.applyOnAliases(value, [&](Value v) {
// Query BufferizableOpInterface to see if the OpResult is writable.
// TODO: Out-of-place bufferized OpResult could be considered writable.
if (auto bufferizableOp = v.getDefiningOp<BufferizableOpInterface>())
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
if (bufferizableOp && bufferizableOp.isWritable(v, state))
return;
// Query BufferizableOpInterface to see if the BlockArgument is writable.
if (auto bbArg = v.dyn_cast<BlockArgument>())
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
bbArg.getOwner()->getParentOp()))
if (bufferizableOp.isWritable(bbArg, state))
return;
@ -324,11 +324,12 @@ static bool happensBefore(Operation *a, Operation *b,
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
static bool hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
BufferizationState &state, const BufferizationAliasInfo &aliasInfo) {
const BufferizationOptions &options = state.getOptions();
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@ -341,7 +342,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
// is %0. Note that operations that create an alias but do not write (such
// as ExtractSliceOp) are skipped.
Value lastWrite = findLastPrecedingWrite(uRead->get());
Value lastWrite = findLastPrecedingWrite(uRead->get(), options);
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
@ -370,15 +371,15 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
continue;
// No conflict if the op interface says so.
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
aliasInfo))
continue;
if (conflictingWritingOp != readingOp)
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
options.dynCastBufferizableOp(conflictingWritingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
aliasInfo))
continue;
@ -452,7 +453,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
/// involving aliases of the given OpOperand are checked.
bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo,
BufferizationState &state, const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
#ifndef NDEBUG
if (result) {
@ -496,7 +497,8 @@ bool wouldCreateReadAfterWriteInterference(
if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
aliasInfo);
}
/// Return true if bufferizing `opOperand` inplace with `opResult` would create
@ -555,7 +557,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) ||
wouldCreateReadAfterWriteInterference(operand, result, domInfo,
wouldCreateReadAfterWriteInterference(operand, result, domInfo, state,
aliasInfo);
if (foundInterference)
@ -603,7 +605,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
for (Operation *op : reverse(ops))
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>())
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
if (failed(bufferizableInPlaceAnalysisImpl(
opOperand, opResult, aliasInfo, state, domInfo)))
@ -633,9 +635,10 @@ static LogicalResult inPlaceAnalysis(Operation *op,
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
const BufferizationOptions &options) {
for (Operation *op : ops)
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
for (OpResult opResult : op->getOpResults())
if (opResult.getType().isa<TensorType>())
if (aliasInfo.isInPlace(opResult)) {
@ -652,7 +655,8 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
const BufferizationOptions &options) {
// Traverse ops in PostOrder: Nested ops first, then enclosing ops.
SmallVector<Operation *> ops;
op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@ -662,21 +666,23 @@ static void equivalenceAnalysis(Operation *op,
ops.push_back(op);
});
equivalenceAnalysis(ops, aliasInfo);
equivalenceAnalysis(ops, aliasInfo, options);
}
/// Assert that the current bufferization decisions are consistent.
static LogicalResult
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
BufferizationState &state,
const BufferizationAliasInfo &aliasInfo) {
const BufferizationOptions &options = state.getOptions();
Operation *inconsistentOp = nullptr;
WalkResult walkResult = op->walk([&](Operation *op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>()) {
OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
if (wouldCreateReadAfterWriteInterference(
opOperand, opResult, domInfo, aliasInfo,
opOperand, opResult, domInfo, state, aliasInfo,
/*checkConsistencyOnly=*/true)) {
// This error can happen for two reasons. Either the input IR
// already has a read-after-write conflict. Or certain
@ -723,14 +729,14 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo)))
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
return failure();
// If the analysis fails, just return.
if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
options.analysisFuzzerSeed)))
return failure();
equivalenceAnalysis(op, aliasInfo);
equivalenceAnalysis(op, aliasInfo, options);
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
@ -740,7 +746,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
return failure();
equivalenceAnalysis(newOps, aliasInfo);
equivalenceAnalysis(newOps, aliasInfo, options);
}
return success();
};

View File

@ -388,6 +388,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps) {
OpBuilder b(op->getContext());
const BufferizationOptions &options = state.getOptions();
WalkResult status = op->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
@ -396,7 +397,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
continue;
SetVector<Value> maybeInitTensor =
findValueInReverseUseDefChain(operand.get(), [&](Value val) {
findValueInReverseUseDefChain(operand.get(), options, [&](Value val) {
// Continue traversal until this function returns true.
OpResult opResult = val.dyn_cast<OpResult>();
if (!opResult)

View File

@ -276,6 +276,7 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
const BufferizationOptions &options,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
@ -284,7 +285,7 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
return false;
};
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
return llvm::all_of(findValueInReverseUseDefChain(value, options, condition),
condition);
}
@ -311,7 +312,7 @@ struct InsertSliceOpInterface
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
OpOperand *uConflictingWrite, BufferizationState &state,
const BufferizationAliasInfo &aliasInfo) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@ -328,8 +329,8 @@ struct InsertSliceOpInterface
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
insertSliceOp))
hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
uConflictingWrite->get(), insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@ -346,7 +347,8 @@ struct InsertSliceOpInterface
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(),
insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@ -379,8 +381,8 @@ struct InsertSliceOpInterface
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
insertSliceOp))
hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
insertSliceOp.source(), insertSliceOp))
return true;
return false;

View File

@ -8,6 +8,8 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
// RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
// CHECK-LABEL: func @use_of_unknown_op_1(
// CHECK-SAME: %[[m1:.*]]: memref<?xf32
func @use_of_unknown_op_1(%t1: tensor<?xf32> {linalg.inplaceable = true})
@ -148,3 +150,20 @@ func @unknown_op_not_writable(
// CHECK: return %[[alloc]]
return %1 : tensor<?xf32>
}
// -----
// CHECK-TENSOR-LABEL: func @simple_tensor_test(
// CHECK-TENSOR-SAME: %[[t1:.*]]: tensor<?xf32>
func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
// CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
%c0 = arith.constant 0 : index
// CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
// CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
// CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
%0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
// CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK-TENSOR: return %[[casted_tensor]]
return %0 : tensor<?xf32>
}

View File

@ -85,6 +85,10 @@ struct TestComprehensiveFunctionBufferize
*this, "analysis-fuzzer-seed",
llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
llvm::cl::init(0)};
ListOption<std::string> dialectFilter{
*this, "dialect-filter",
llvm::cl::desc("Bufferize only ops from the specified dialects"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace
@ -104,6 +108,12 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
options.testAnalysisOnly = testAnalysisOnly;
options.analysisFuzzerSeed = analysisFuzzerSeed;
if (dialectFilter.hasValue()) {
options.dialectFilter.emplace();
for (const std::string &dialectNamespace : dialectFilter)
options.dialectFilter->insert(dialectNamespace);
}
Operation *op = getFunction().getOperation();
if (failed(runComprehensiveBufferize(op, options)))
return;