From 847710f7b77ea4e3cd43f62b5b7d920ac47405a5 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 8 Dec 2021 23:26:22 +0900 Subject: [PATCH] [mlir][linalg][bufferize] Add dialect filter to BufferizationOptions This adds a new option `dialectFilter` to BufferizationOptions. Only ops from dialects that are allow-listed in the filter are bufferized. Other ops are left unbufferized. Note: This option requires `allowUnknownOps = true`. To make use of `dialectFilter`, BufferizationOptions or BufferizationState must be passed to various helper functions. The purpose of this change is to provide a better infrastructure for partial bufferization, which will be fully activated in a subsequent change. Differential Revision: https://reviews.llvm.org/D114691 --- .../BufferizableOpInterface.h | 56 +++++++++++++++++-- .../BufferizableOpInterface.td | 1 + .../BufferizableOpInterface.cpp | 42 ++++++++++---- .../ComprehensiveBufferize.cpp | 56 ++++++++++--------- .../LinalgInterfaceImpl.cpp | 3 +- .../TensorInterfaceImpl.cpp | 16 +++--- ...omprehensive-module-bufferize-partial.mlir | 19 +++++++ .../Linalg/TestComprehensiveBufferize.cpp | 10 ++++ 8 files changed, 154 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h index 527f11d4d93e..df327aa5e243 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -30,6 +30,7 @@ namespace comprehensive_bufferize { static constexpr int64_t kBufferAlignments = 128; class BufferizationAliasInfo; +class BufferizableOpInterface; struct BufferizationOptions; class BufferizationState; struct PostAnalysisStep; @@ -92,6 +93,33 @@ struct BufferizationOptions { std::make_unique(std::forward(args)...)); } + /// Return `true` if the op is allowed to be bufferized. + bool isOpAllowed(Operation *op) const { + if (!dialectFilter.hasValue()) + return true; + return dialectFilter->contains(op->getDialect()->getNamespace()); + } + + /// Allow-list the given dialects in the dialect filter. Only ops from + /// allow-listed dialects will be bufferized. + template + void addToDialectFilter() { + // The following expands a call to addToDialectFilterImpl for each dialect + // in 'DialectTs'. This magic is necessary due to a limitation in the places + // that a parameter pack can be expanded in c++11. + // FIXME: In c++17 this can be simplified by using 'fold expressions'. + (void)std::initializer_list{ + 0, (addToDialectFilterImpl(), 0)...}; + } + + /// Try to cast the given op to BufferizableOpInterface if the op is allow + /// listed. + BufferizableOpInterface dynCastBufferizableOp(Operation *op) const; + + /// Try to cast the given value to BufferizableOpInterface if the op is allow + /// listed. + BufferizableOpInterface dynCastBufferizableOp(Value value) const; + /// Helper functions for allocation, deallocation, memory copying. std::unique_ptr allocationFns; @@ -114,6 +142,25 @@ struct BufferizationOptions { /// Registered post analysis steps. PostAnalysisStepList postAnalysisSteps; + + /// Only bufferize ops from dialects that are allowed-listed by the filter. + /// All other ops are ignored. This option controls the scope of partial + /// bufferization. + /// + /// Note: If no filter is specified, all ops are bufferized (as long as they + /// implement BufferizableOpInterface). If a filter is specified, + /// `allowUnknownOps` should be enabled. Otherwise, bufferization would fail + /// when encountering an op that is forbidden by the filter. + Optional> dialectFilter; + +private: + /// Allow-list a dialect in the dialect filter. + template + void addToDialectFilterImpl() { + if (!dialectFilter.hasValue()) + dialectFilter.emplace(); + dialectFilter->insert(DialectT::getDialectNamespace()); + } }; /// Specify fine-grain relationship between buffers to enable more analysis. @@ -128,7 +175,8 @@ enum class BufferRelation { /// equivalence classes to support bufferization. class BufferizationAliasInfo { public: - explicit BufferizationAliasInfo(Operation *rootOp); + explicit BufferizationAliasInfo(Operation *rootOp, + const BufferizationOptions &options); // BufferizationAliasInfo should be passed as a reference. BufferizationAliasInfo(const BufferizationAliasInfo &) = delete; @@ -265,7 +313,7 @@ bool isValueRead(Value value); /// starting the traversal from Value 1, the resulting SetVector is: /// { 2, 7, 8, 5 } llvm::SetVector -findValueInReverseUseDefChain(Value value, +findValueInReverseUseDefChain(Value value, const BufferizationOptions &options, std::function condition); /// Find the Value of the last preceding write of a given Value. @@ -276,7 +324,7 @@ findValueInReverseUseDefChain(Value value, /// /// Note: When reaching an end of the reverse SSA use-def chain, that value /// is returned regardless of whether it is a memory write or not. -Value findLastPrecedingWrite(Value value); +Value findLastPrecedingWrite(Value value, const BufferizationOptions &options); /// Dialect-specific bufferization state. Analysis/bufferization information /// that is specific to ops from a certain dialect can be stored in derived @@ -300,7 +348,7 @@ struct DialectBufferizationState { class BufferizationState { public: BufferizationState(Operation *op, const BufferizationOptions &options) - : aliasInfo(op), options(options), builder(op->getContext()) {} + : aliasInfo(op, options), options(options), builder(op->getContext()) {} // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td index 6a35c0e3bb52..a81b52d1433f 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -266,6 +266,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*methodName=*/"isNotConflicting", /*args=*/(ins "OpOperand *":$uRead, "OpOperand *":$uWrite, + "BufferizationState &":$state, "const BufferizationAliasInfo &":$aliasInfo), /*methodBody=*/"", /*defaultImplementation=*/[{ diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp index 3f9c8979e184..ffb8a7a25c27 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -78,7 +78,8 @@ BufferizationOptions::BufferizationOptions() // BufferizationAliasInfo //===----------------------------------------------------------------------===// -BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { +BufferizationAliasInfo::BufferizationAliasInfo( + Operation *rootOp, const BufferizationOptions &options) { rootOp->walk([&](Operation *op) { for (Value v : op->getResults()) if (v.getType().isa()) @@ -93,6 +94,8 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { // Set up alias sets for OpResults that must bufferize in-place. This should // be done before making any other bufferization decisions. rootOp->walk([&](BufferizableOpInterface bufferizableOp) { + if (!options.isOpAllowed(bufferizableOp)) + return WalkResult::skip(); for (OpResult opResult : bufferizableOp->getOpResults()) { if (opResult.getType().isa()) if (bufferizableOp.mustBufferizeInPlace(opResult)) { @@ -105,6 +108,7 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { markInPlace(opResult); } } + return WalkResult::advance(); }); } @@ -197,6 +201,21 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) { } } +BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: + BufferizationOptions::dynCastBufferizableOp(Operation *op) const { + if (isOpAllowed(op)) + return dyn_cast(op); + return nullptr; +} + +BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: + BufferizationOptions::dynCastBufferizableOp(Value value) const { + if (auto bufferizableOp = value.getDefiningOp()) + if (isOpAllowed(bufferizableOp.getOperation())) + return bufferizableOp; + return nullptr; +} + /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector @@ -283,7 +302,8 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) { // further. llvm::SetVector mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain( - Value value, std::function condition) { + Value value, const BufferizationOptions &options, + std::function condition) { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -296,7 +316,7 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain( OpResult opResult = value.cast(); SmallVector opOperands = getAliasingOpOperand(opResult); - if (opOperands.empty()) { + if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { result.insert(value); continue; } @@ -310,13 +330,13 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain( // Find the Value of the last preceding write of a given Value. Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite( - Value value) { + Value value, const BufferizationOptions &options) { SetVector result = - findValueInReverseUseDefChain(value, [](Value value) { + findValueInReverseUseDefChain(value, options, [&](Value value) { Operation *op = value.getDefiningOp(); if (!op) return true; - auto bufferizableOp = dyn_cast(op); + auto bufferizableOp = options.dynCastBufferizableOp(op); if (!bufferizableOp) return true; return bufferizableOp.isMemoryWrite(value.cast()); @@ -374,9 +394,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState:: // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA // use-def chain, it returns that value, regardless of whether it is a // memory write or not. - Value lastWrite = findLastPrecedingWrite(operand); - if (auto bufferizableOp = - lastWrite.getDefiningOp()) + Value lastWrite = findLastPrecedingWrite(operand, options); + if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) if (!bufferizableOp.isMemoryWrite(lastWrite.cast())) skipCopy = true; // Do not copy if the copied data is never read. @@ -433,7 +452,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op, // Bufferize using `BufferizableOpInterface`. Interface implementations are // responsible for bufferizing nested ops. - if (auto bufferizableOp = dyn_cast(op)) { + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { b.setInsertionPoint(op); return bufferizableOp.bufferize(b, state); } @@ -640,8 +659,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer( if (options.allowUnknownOps) { // `tensor` was not bufferized yet. This should never happen with // bufferizable ops. - assert(!tensor.getDefiningOp() && - "tensor is not mapped"); + assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped"); // Insert to_memref op. OpBuilder b(tensor.getContext()); setInsertionPointAfter(b, tensor); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp index 7f2ae60b1309..6cfae7cc702e 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -256,13 +256,13 @@ static bool aliasesNonWritableBuffer(Value value, aliasInfo.applyOnAliases(value, [&](Value v) { // Query BufferizableOpInterface to see if the OpResult is writable. // TODO: Out-of-place bufferized OpResult could be considered writable. - if (auto bufferizableOp = v.getDefiningOp()) + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v)) if (bufferizableOp && bufferizableOp.isWritable(v, state)) return; // Query BufferizableOpInterface to see if the BlockArgument is writable. if (auto bbArg = v.dyn_cast()) - if (auto bufferizableOp = dyn_cast( + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp( bbArg.getOwner()->getParentOp())) if (bufferizableOp.isWritable(bbArg, state)) return; @@ -324,11 +324,12 @@ static bool happensBefore(Operation *a, Operation *b, /// A conflict is: According to SSA use-def chains, a read R is supposed to read /// the result of a write W1. But because of bufferization decisions, R actually /// reads another write W2. -static bool -hasReadAfterWriteInterference(const DenseSet &usesRead, - const DenseSet &usesWrite, - const DominanceInfo &domInfo, - const BufferizationAliasInfo &aliasInfo) { +static bool hasReadAfterWriteInterference( + const DenseSet &usesRead, + const DenseSet &usesWrite, const DominanceInfo &domInfo, + BufferizationState &state, const BufferizationAliasInfo &aliasInfo) { + const BufferizationOptions &options = state.getOptions(); + for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); @@ -341,7 +342,7 @@ hasReadAfterWriteInterference(const DenseSet &usesRead, // In the above example, if uRead is the OpOperand of reading_op, lastWrite // is %0. Note that operations that create an alias but do not write (such // as ExtractSliceOp) are skipped. - Value lastWrite = findLastPrecedingWrite(uRead->get()); + Value lastWrite = findLastPrecedingWrite(uRead->get(), options); // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. @@ -370,15 +371,15 @@ hasReadAfterWriteInterference(const DenseSet &usesRead, continue; // No conflict if the op interface says so. - if (auto bufferizableOp = dyn_cast(readingOp)) - if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, + if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) + if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state, aliasInfo)) continue; if (conflictingWritingOp != readingOp) if (auto bufferizableOp = - dyn_cast(conflictingWritingOp)) - if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, + options.dynCastBufferizableOp(conflictingWritingOp)) + if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state, aliasInfo)) continue; @@ -452,7 +453,7 @@ hasReadAfterWriteInterference(const DenseSet &usesRead, /// involving aliases of the given OpOperand are checked. bool wouldCreateReadAfterWriteInterference( OpOperand &operand, OpResult result, const DominanceInfo &domInfo, - const BufferizationAliasInfo &aliasInfo, + BufferizationState &state, const BufferizationAliasInfo &aliasInfo, bool checkConsistencyOnly = false) { #ifndef NDEBUG if (result) { @@ -496,7 +497,8 @@ bool wouldCreateReadAfterWriteInterference( if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); - return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo); + return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state, + aliasInfo); } /// Return true if bufferizing `opOperand` inplace with `opResult` would create @@ -555,7 +557,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl( bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) || - wouldCreateReadAfterWriteInterference(operand, result, domInfo, + wouldCreateReadAfterWriteInterference(operand, result, domInfo, state, aliasInfo); if (foundInterference) @@ -603,7 +605,7 @@ static LogicalResult inPlaceAnalysis(SmallVector &ops, for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) - if (auto bufferizableOp = dyn_cast(op)) + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand)) if (failed(bufferizableInPlaceAnalysisImpl( opOperand, opResult, aliasInfo, state, domInfo))) @@ -633,9 +635,10 @@ static LogicalResult inPlaceAnalysis(Operation *op, /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. static void equivalenceAnalysis(SmallVector &ops, - BufferizationAliasInfo &aliasInfo) { + BufferizationAliasInfo &aliasInfo, + const BufferizationOptions &options) { for (Operation *op : ops) - if (auto bufferizableOp = dyn_cast(op)) + if (auto bufferizableOp = options.dynCastBufferizableOp(op)) for (OpResult opResult : op->getOpResults()) if (opResult.getType().isa()) if (aliasInfo.isInPlace(opResult)) { @@ -652,7 +655,8 @@ static void equivalenceAnalysis(SmallVector &ops, /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained /// in `op`. static void equivalenceAnalysis(Operation *op, - BufferizationAliasInfo &aliasInfo) { + BufferizationAliasInfo &aliasInfo, + const BufferizationOptions &options) { // Traverse ops in PostOrder: Nested ops first, then enclosing ops. SmallVector ops; op->walk([&](Operation *op) { @@ -662,21 +666,23 @@ static void equivalenceAnalysis(Operation *op, ops.push_back(op); }); - equivalenceAnalysis(ops, aliasInfo); + equivalenceAnalysis(ops, aliasInfo, options); } /// Assert that the current bufferization decisions are consistent. static LogicalResult checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, + BufferizationState &state, const BufferizationAliasInfo &aliasInfo) { + const BufferizationOptions &options = state.getOptions(); Operation *inconsistentOp = nullptr; WalkResult walkResult = op->walk([&](Operation *op) { - if (auto bufferizableOp = dyn_cast(op)) + if (auto bufferizableOp = options.dynCastBufferizableOp(op)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) { OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand); if (wouldCreateReadAfterWriteInterference( - opOperand, opResult, domInfo, aliasInfo, + opOperand, opResult, domInfo, state, aliasInfo, /*checkConsistencyOnly=*/true)) { // This error can happen for two reasons. Either the input IR // already has a read-after-write conflict. Or certain @@ -723,14 +729,14 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.aliasInfo; - if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo))) + if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); // If the analysis fails, just return. if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, options.analysisFuzzerSeed))) return failure(); - equivalenceAnalysis(op, aliasInfo); + equivalenceAnalysis(op, aliasInfo, options); auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) { for (const std::unique_ptr &step : steps) { @@ -740,7 +746,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( // Analyze ops that were created by the PostAnalysisStep. if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) return failure(); - equivalenceAnalysis(newOps, aliasInfo); + equivalenceAnalysis(newOps, aliasInfo, options); } return success(); }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp index b68f1a2da516..3ac95dbe18c4 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -388,6 +388,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: std::function rewriteFunc, SmallVector &newOps) { OpBuilder b(op->getContext()); + const BufferizationOptions &options = state.getOptions(); WalkResult status = op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { @@ -396,7 +397,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: continue; SetVector maybeInitTensor = - findValueInReverseUseDefChain(operand.get(), [&](Value val) { + findValueInReverseUseDefChain(operand.get(), options, [&](Value val) { // Continue traversal until this function returns true. OpResult opResult = val.dyn_cast(); if (!opResult) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp index ca38d27e121e..cfc04be793b1 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -276,6 +276,7 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, + const BufferizationOptions &options, Value value, InsertSliceOp insertOp) { auto condition = [&](Value val) { if (auto extractOp = val.getDefiningOp()) @@ -284,7 +285,7 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, return false; }; - return llvm::all_of(findValueInReverseUseDefChain(value, condition), + return llvm::all_of(findValueInReverseUseDefChain(value, options, condition), condition); } @@ -311,7 +312,7 @@ struct InsertSliceOpInterface } bool isNotConflicting(Operation *op, OpOperand *uRead, - OpOperand *uConflictingWrite, + OpOperand *uConflictingWrite, BufferizationState &state, const BufferizationAliasInfo &aliasInfo) const { Operation *readingOp = uRead->getOwner(); Operation *conflictingWritingOp = uConflictingWrite->getOwner(); @@ -328,8 +329,8 @@ struct InsertSliceOpInterface // TODO: Use insertSliceOp.getDestOpOperand etc. when available. if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(), - insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), + uConflictingWrite->get(), insertSliceOp)) // Case 1: The main insight is that InsertSliceOp reads only part of // the destination tensor. The overwritten area is not read. If // uConflictingWrite writes into exactly the memory location that is @@ -346,7 +347,8 @@ struct InsertSliceOpInterface if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(), + insertSliceOp)) // Case 2: The read of the source tensor and the write to the dest // tensor via an InsertSliceOp is not a conflict if the read is // reading exactly that part of an equivalent tensor that the @@ -379,8 +381,8 @@ struct InsertSliceOpInterface if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && aliasInfo.areEquivalentBufferizedValues(uRead->get(), insertSliceOp.source()) && - hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(), - insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), + insertSliceOp.source(), insertSliceOp)) return true; return false; diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir index 6da6b2a514dc..2870d40f076c 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir @@ -8,6 +8,8 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR + // CHECK-LABEL: func @use_of_unknown_op_1( // CHECK-SAME: %[[m1:.*]]: memref {linalg.inplaceable = true}) @@ -148,3 +150,20 @@ func @unknown_op_not_writable( // CHECK: return %[[alloc]] return %1 : tensor } + +// ----- + +// CHECK-TENSOR-LABEL: func @simple_tensor_test( +// CHECK-TENSOR-SAME: %[[t1:.*]]: tensor +func @simple_tensor_test(%t1 : tensor, %f : f32) -> tensor { + // CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]] + %c0 = arith.constant 0 : index + // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc + // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]] + // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]] + %0 = tensor.insert %f into %t1[%c0] : tensor + // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]] + // CHECK-TENSOR: return %[[casted_tensor]] + return %0 : tensor +} diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp index 5ac15a979d00..fae27fc1a3f4 100644 --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -85,6 +85,10 @@ struct TestComprehensiveFunctionBufferize *this, "analysis-fuzzer-seed", llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"), llvm::cl::init(0)}; + ListOption dialectFilter{ + *this, "dialect-filter", + llvm::cl::desc("Bufferize only ops from the specified dialects"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace @@ -104,6 +108,12 @@ void TestComprehensiveFunctionBufferize::runOnFunction() { options.testAnalysisOnly = testAnalysisOnly; options.analysisFuzzerSeed = analysisFuzzerSeed; + if (dialectFilter.hasValue()) { + options.dialectFilter.emplace(); + for (const std::string &dialectNamespace : dialectFilter) + options.dialectFilter->insert(dialectNamespace); + } + Operation *op = getFunction().getOperation(); if (failed(runComprehensiveBufferize(op, options))) return;